1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/Endian.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 //===----------------------------------------------------------------------===//
28 /// Tablegen Attribute Definitions
29 //===----------------------------------------------------------------------===//
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/IR/BuiltinAttributes.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // BuiltinDialect
36 //===----------------------------------------------------------------------===//
37 
38 void BuiltinDialect::registerAttributes() {
39   addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
40                 DenseIntOrFPElementsAttr, DenseStringElementsAttr,
41                 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
42                 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
43                 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // ArrayAttr
48 //===----------------------------------------------------------------------===//
49 
50 void ArrayAttr::walkImmediateSubElements(
51     function_ref<void(Attribute)> walkAttrsFn,
52     function_ref<void(Type)> walkTypesFn) const {
53   for (Attribute attr : getValue())
54     walkAttrsFn(attr);
55 }
56 
57 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
58     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
59   std::vector<Attribute> vector = getValue().vec();
60   for (auto &it : replacements) {
61     vector[it.first] = it.second;
62   }
63   return get(getContext(), vector);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // DictionaryAttr
68 //===----------------------------------------------------------------------===//
69 
70 /// Helper function that does either an in place sort or sorts from source array
71 /// into destination. If inPlace then storage is both the source and the
72 /// destination, else value is the source and storage destination. Returns
73 /// whether source was sorted.
74 template <bool inPlace>
75 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
76                                SmallVectorImpl<NamedAttribute> &storage) {
77   // Specialize for the common case.
78   switch (value.size()) {
79   case 0:
80     // Zero already sorted.
81     if (!inPlace)
82       storage.clear();
83     break;
84   case 1:
85     // One already sorted but may need to be copied.
86     if (!inPlace)
87       storage.assign({value[0]});
88     break;
89   case 2: {
90     bool isSorted = value[0] < value[1];
91     if (inPlace) {
92       if (!isSorted)
93         std::swap(storage[0], storage[1]);
94     } else if (isSorted) {
95       storage.assign({value[0], value[1]});
96     } else {
97       storage.assign({value[1], value[0]});
98     }
99     return !isSorted;
100   }
101   default:
102     if (!inPlace)
103       storage.assign(value.begin(), value.end());
104     // Check to see they are sorted already.
105     bool isSorted = llvm::is_sorted(value);
106     // If not, do a general sort.
107     if (!isSorted)
108       llvm::array_pod_sort(storage.begin(), storage.end());
109     return !isSorted;
110   }
111   return false;
112 }
113 
114 /// Returns an entry with a duplicate name from the given sorted array of named
115 /// attributes. Returns llvm::None if all elements have unique names.
116 static Optional<NamedAttribute>
117 findDuplicateElement(ArrayRef<NamedAttribute> value) {
118   const Optional<NamedAttribute> none{llvm::None};
119   if (value.size() < 2)
120     return none;
121 
122   if (value.size() == 2)
123     return value[0].getName() == value[1].getName() ? value[0] : none;
124 
125   const auto *it = std::adjacent_find(value.begin(), value.end(),
126                                       [](NamedAttribute l, NamedAttribute r) {
127                                         return l.getName() == r.getName();
128                                       });
129   return it != value.end() ? *it : none;
130 }
131 
132 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
133                           SmallVectorImpl<NamedAttribute> &storage) {
134   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
135   assert(!findDuplicateElement(storage) &&
136          "DictionaryAttr element names must be unique");
137   return isSorted;
138 }
139 
140 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
141   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
142   assert(!findDuplicateElement(array) &&
143          "DictionaryAttr element names must be unique");
144   return isSorted;
145 }
146 
147 Optional<NamedAttribute>
148 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
149                               bool isSorted) {
150   if (!isSorted)
151     dictionaryAttrSort</*inPlace=*/true>(array, array);
152   return findDuplicateElement(array);
153 }
154 
155 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
156                                    ArrayRef<NamedAttribute> value) {
157   if (value.empty())
158     return DictionaryAttr::getEmpty(context);
159 
160   // We need to sort the element list to canonicalize it.
161   SmallVector<NamedAttribute, 8> storage;
162   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
163     value = storage;
164   assert(!findDuplicateElement(value) &&
165          "DictionaryAttr element names must be unique");
166   return Base::get(context, value);
167 }
168 /// Construct a dictionary with an array of values that is known to already be
169 /// sorted by name and uniqued.
170 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
171                                              ArrayRef<NamedAttribute> value) {
172   if (value.empty())
173     return DictionaryAttr::getEmpty(context);
174   // Ensure that the attribute elements are unique and sorted.
175   assert(llvm::is_sorted(
176              value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
177          "expected attribute values to be sorted");
178   assert(!findDuplicateElement(value) &&
179          "DictionaryAttr element names must be unique");
180   return Base::get(context, value);
181 }
182 
183 /// Return the specified attribute if present, null otherwise.
184 Attribute DictionaryAttr::get(StringRef name) const {
185   auto it = impl::findAttrSorted(begin(), end(), name);
186   return it.second ? it.first->getValue() : Attribute();
187 }
188 Attribute DictionaryAttr::get(StringAttr name) const {
189   auto it = impl::findAttrSorted(begin(), end(), name);
190   return it.second ? it.first->getValue() : Attribute();
191 }
192 
193 /// Return the specified named attribute if present, None otherwise.
194 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
195   auto it = impl::findAttrSorted(begin(), end(), name);
196   return it.second ? *it.first : Optional<NamedAttribute>();
197 }
198 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
199   auto it = impl::findAttrSorted(begin(), end(), name);
200   return it.second ? *it.first : Optional<NamedAttribute>();
201 }
202 
203 /// Return whether the specified attribute is present.
204 bool DictionaryAttr::contains(StringRef name) const {
205   return impl::findAttrSorted(begin(), end(), name).second;
206 }
207 bool DictionaryAttr::contains(StringAttr name) const {
208   return impl::findAttrSorted(begin(), end(), name).second;
209 }
210 
211 DictionaryAttr::iterator DictionaryAttr::begin() const {
212   return getValue().begin();
213 }
214 DictionaryAttr::iterator DictionaryAttr::end() const {
215   return getValue().end();
216 }
217 size_t DictionaryAttr::size() const { return getValue().size(); }
218 
219 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
220   return Base::get(context, ArrayRef<NamedAttribute>());
221 }
222 
223 void DictionaryAttr::walkImmediateSubElements(
224     function_ref<void(Attribute)> walkAttrsFn,
225     function_ref<void(Type)> walkTypesFn) const {
226   for (const NamedAttribute &attr : getValue())
227     walkAttrsFn(attr.getValue());
228 }
229 
230 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
231     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
232   std::vector<NamedAttribute> vec = getValue().vec();
233   for (auto &it : replacements)
234     vec[it.first].setValue(it.second);
235 
236   // The above only modifies the mapped value, but not the key, and therefore
237   // not the order of the elements. It remains sorted
238   return getWithSorted(getContext(), vec);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // StringAttr
243 //===----------------------------------------------------------------------===//
244 
245 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
246   return Base::get(context, "", NoneType::get(context));
247 }
248 
249 /// Twine support for StringAttr.
250 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
251   // Fast-path empty twine.
252   if (twine.isTriviallyEmpty())
253     return get(context);
254   SmallVector<char, 32> tempStr;
255   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
256 }
257 
258 /// Twine support for StringAttr.
259 StringAttr StringAttr::get(const Twine &twine, Type type) {
260   SmallVector<char, 32> tempStr;
261   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
262 }
263 
264 StringRef StringAttr::getValue() const { return getImpl()->value; }
265 
266 Dialect *StringAttr::getReferencedDialect() const {
267   return getImpl()->referencedDialect;
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // FloatAttr
272 //===----------------------------------------------------------------------===//
273 
274 double FloatAttr::getValueAsDouble() const {
275   return getValueAsDouble(getValue());
276 }
277 double FloatAttr::getValueAsDouble(APFloat value) {
278   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
279     bool losesInfo = false;
280     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
281                   &losesInfo);
282   }
283   return value.convertToDouble();
284 }
285 
286 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
287                                 Type type, APFloat value) {
288   // Verify that the type is correct.
289   if (!type.isa<FloatType>())
290     return emitError() << "expected floating point type";
291 
292   // Verify that the type semantics match that of the value.
293   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
294     return emitError()
295            << "FloatAttr type doesn't match the type implied by its value";
296   }
297   return success();
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // SymbolRefAttr
302 //===----------------------------------------------------------------------===//
303 
304 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
305                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
306   return get(StringAttr::get(ctx, value), nestedRefs);
307 }
308 
309 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
310   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
311 }
312 
313 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
314   return get(value, {}).cast<FlatSymbolRefAttr>();
315 }
316 
317 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
318   auto symName =
319       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
320   assert(symName && "value does not have a valid symbol name");
321   return SymbolRefAttr::get(symName);
322 }
323 
324 StringAttr SymbolRefAttr::getLeafReference() const {
325   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
326   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // IntegerAttr
331 //===----------------------------------------------------------------------===//
332 
333 int64_t IntegerAttr::getInt() const {
334   assert((getType().isIndex() || getType().isSignlessInteger()) &&
335          "must be signless integer");
336   return getValue().getSExtValue();
337 }
338 
339 int64_t IntegerAttr::getSInt() const {
340   assert(getType().isSignedInteger() && "must be signed integer");
341   return getValue().getSExtValue();
342 }
343 
344 uint64_t IntegerAttr::getUInt() const {
345   assert(getType().isUnsignedInteger() && "must be unsigned integer");
346   return getValue().getZExtValue();
347 }
348 
349 /// Return the value as an APSInt which carries the signed from the type of
350 /// the attribute.  This traps on signless integers types!
351 APSInt IntegerAttr::getAPSInt() const {
352   assert(!getType().isSignlessInteger() &&
353          "Signless integers don't carry a sign for APSInt");
354   return APSInt(getValue(), getType().isUnsignedInteger());
355 }
356 
357 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
358                                   Type type, APInt value) {
359   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
360     if (integerType.getWidth() != value.getBitWidth())
361       return emitError() << "integer type bit width (" << integerType.getWidth()
362                          << ") doesn't match value bit width ("
363                          << value.getBitWidth() << ")";
364     return success();
365   }
366   if (type.isa<IndexType>())
367     return success();
368   return emitError() << "expected integer or index type";
369 }
370 
371 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
372   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
373   return attr.cast<BoolAttr>();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // BoolAttr
378 //===----------------------------------------------------------------------===//
379 
380 bool BoolAttr::getValue() const {
381   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
382   return storage->value.getBoolValue();
383 }
384 
385 bool BoolAttr::classof(Attribute attr) {
386   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
387   return intAttr && intAttr.getType().isSignlessInteger(1);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // OpaqueAttr
392 //===----------------------------------------------------------------------===//
393 
394 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
395                                  StringAttr dialect, StringRef attrData,
396                                  Type type) {
397   if (!Dialect::isValidNamespace(dialect.strref()))
398     return emitError() << "invalid dialect namespace '" << dialect << "'";
399 
400   // Check that the dialect is actually registered.
401   MLIRContext *context = dialect.getContext();
402   if (!context->allowsUnregisteredDialects() &&
403       !context->getLoadedDialect(dialect.strref())) {
404     return emitError()
405            << "#" << dialect << "<\"" << attrData << "\"> : " << type
406            << " attribute created with unregistered dialect. If this is "
407               "intended, please call allowUnregisteredDialects() on the "
408               "MLIRContext, or use -allow-unregistered-dialect with "
409               "the MLIR opt tool used";
410   }
411 
412   return success();
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // DenseElementsAttr Utilities
417 //===----------------------------------------------------------------------===//
418 
419 /// Get the bitwidth of a dense element type within the buffer.
420 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
421 static size_t getDenseElementStorageWidth(size_t origWidth) {
422   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
423 }
424 static size_t getDenseElementStorageWidth(Type elementType) {
425   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
426 }
427 
/// Sets (`value == true`) or clears (`value == false`) bit `bitPos` of the
/// packed bit array `rawData`. Bit 0 is the least significant bit of
/// rawData[0].
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? static_cast<char>(byte | mask)
               : static_cast<char>(byte & ~mask);
}
435 
/// Returns whether bit `bitPos` of the packed bit array `rawData` is set.
/// Bit 0 is the least significant bit of rawData[0].
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
440 
/// Copy the low `numBytes` bytes of data from `value` (APInt) into the char
/// array `result`, for use on a big-endian host.
///
/// All words of `value` except the last are copied verbatim. The last word is
/// byte-order adjusted (via two calls to
/// DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine), and only its
/// first `numBytes - lastWordPos` data bytes are written into `result`.
///
/// Preconditions (asserted): the host is big-endian, and `value` has at least
/// `numBytes` bytes of word storage.
/// NOTE(review): also assumes `numBytes` is larger than the byte size of all
/// but the last word; otherwise `numBytes - lastWordPos` would wrap around.
/// writeBits in this file passes divideCeil(bitWidth, CHAR_BIT) for a
/// `bitWidth`-bit `value`, which satisfies this.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a BYTE count (full words times
  // APINT_WORD_SIZE), and is used as a byte offset below.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert the last (partially filled) word of the APInt to LE format and
  // store it in the char array `valueLE`.
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract only the remaining data bytes from `valueLE`, convert them back to
  // BE format, and store them right after the full words in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
470 
/// Copy `numBytes` bytes of data from the char array `inArray` into `result`
/// (APInt), for use on a big-endian host. This mirrors
/// copyAPIntToArrayForBEmachine.
///
/// All but the last word of `result` are filled verbatim from `inArray`. The
/// remaining `numBytes - lastWordPos` bytes are byte-order adjusted (via two
/// calls to DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine) into
/// the last word of `result`. `result` is written through a const_cast of
/// getRawData().
///
/// Preconditions (asserted): the host is big-endian, and `result` has at least
/// `numBytes` bytes of word storage.
/// NOTE(review): also assumes `numBytes` is larger than the byte size of all
/// but the last word; otherwise `numBytes - lastWordPos` would wrap around.
/// readBits in this file passes divideCeil(bitWidth, CHAR_BIT) for a
/// `bitWidth`-bit `result`, which satisfies this.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: despite its name, `numFilledWords` is a BYTE count (full words times
  // APINT_WORD_SIZE), and is used as a byte offset below.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert the array data that will become the last word of `result` to LE
  // format, and store it in the char array `inArrayLE`.
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` (a whole word) to BE format, and store it in the last
  // word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
506 
507 /// Writes value to the bit position `bitPos` in array `rawData`.
508 static void writeBits(char *rawData, size_t bitPos, APInt value) {
509   size_t bitWidth = value.getBitWidth();
510 
511   // If the bitwidth is 1 we just toggle the specific bit.
512   if (bitWidth == 1)
513     return setBit(rawData, bitPos, value.isOneValue());
514 
515   // Otherwise, the bit position is guaranteed to be byte aligned.
516   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
517   if (llvm::support::endian::system_endianness() ==
518       llvm::support::endianness::big) {
519     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
520     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
521     // work correctly in BE format.
522     // ex. `value` (2 words including 10 bytes)
523     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
524     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
525                                  rawData + (bitPos / CHAR_BIT));
526   } else {
527     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
528                 llvm::divideCeil(bitWidth, CHAR_BIT),
529                 rawData + (bitPos / CHAR_BIT));
530   }
531 }
532 
533 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
534 /// `rawData`.
535 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
536   // Handle a boolean bit position.
537   if (bitWidth == 1)
538     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
539 
540   // Otherwise, the bit position must be 8-bit aligned.
541   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
542   APInt result(bitWidth, 0);
543   if (llvm::support::endian::system_endianness() ==
544       llvm::support::endianness::big) {
545     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
546     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
547     // work correctly in BE format.
548     // ex. `result` (2 words including 10 bytes)
549     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
550     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
551                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
552   } else {
553     std::copy_n(rawData + (bitPos / CHAR_BIT),
554                 llvm::divideCeil(bitWidth, CHAR_BIT),
555                 const_cast<char *>(
556                     reinterpret_cast<const char *>(result.getRawData())));
557   }
558   return result;
559 }
560 
561 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
562 /// the same element count as 'type'.
563 template <typename Values>
564 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
565   return (values.size() == 1) ||
566          (type.getNumElements() == static_cast<int64_t>(values.size()));
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // DenseElementsAttr Iterators
571 //===----------------------------------------------------------------------===//
572 
573 //===----------------------------------------------------------------------===//
574 // AttributeElementIterator
575 
576 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
577     DenseElementsAttr attr, size_t index)
578     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
579                                       Attribute, Attribute, Attribute>(
580           attr.getAsOpaquePointer(), index) {}
581 
582 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
583   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
584   Type eltTy = owner.getElementType();
585   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
586     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
587   if (eltTy.isa<IndexType>())
588     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
589   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
590     IntElementIterator intIt(owner, index);
591     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
592     return FloatAttr::get(eltTy, *floatIt);
593   }
594   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
595     auto complexEltTy = complexTy.getElementType();
596     ComplexIntElementIterator complexIntIt(owner, index);
597     if (complexEltTy.isa<IntegerType>()) {
598       auto value = *complexIntIt;
599       auto real = IntegerAttr::get(complexEltTy, value.real());
600       auto imag = IntegerAttr::get(complexEltTy, value.imag());
601       return ArrayAttr::get(complexTy.getContext(),
602                             ArrayRef<Attribute>{real, imag});
603     }
604 
605     ComplexFloatElementIterator complexFloatIt(
606         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
607     auto value = *complexFloatIt;
608     auto real = FloatAttr::get(complexEltTy, value.real());
609     auto imag = FloatAttr::get(complexEltTy, value.imag());
610     return ArrayAttr::get(complexTy.getContext(),
611                           ArrayRef<Attribute>{real, imag});
612   }
613   if (owner.isa<DenseStringElementsAttr>()) {
614     ArrayRef<StringRef> vals = owner.getRawStringData();
615     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
616   }
617   llvm_unreachable("unexpected element type");
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // BoolElementIterator
622 
623 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
624     DenseElementsAttr attr, size_t dataIndex)
625     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
626           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
627 
628 bool DenseElementsAttr::BoolElementIterator::operator*() const {
629   return getBit(getData(), getDataIndex());
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // IntElementIterator
634 
635 DenseElementsAttr::IntElementIterator::IntElementIterator(
636     DenseElementsAttr attr, size_t dataIndex)
637     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
638           attr.getRawData().data(), attr.isSplat(), dataIndex),
639       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
640 
641 APInt DenseElementsAttr::IntElementIterator::operator*() const {
642   return readBits(getData(),
643                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
644                   bitWidth);
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // ComplexIntElementIterator
649 
650 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
651     DenseElementsAttr attr, size_t dataIndex)
652     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
653                                       std::complex<APInt>, std::complex<APInt>,
654                                       std::complex<APInt>>(
655           attr.getRawData().data(), attr.isSplat(), dataIndex) {
656   auto complexType = attr.getElementType().cast<ComplexType>();
657   bitWidth = getDenseElementBitWidth(complexType.getElementType());
658 }
659 
660 std::complex<APInt>
661 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
662   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
663   size_t offset = getDataIndex() * storageWidth * 2;
664   return {readBits(getData(), offset, bitWidth),
665           readBits(getData(), offset + storageWidth, bitWidth)};
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // DenseArrayAttr
670 //===----------------------------------------------------------------------===//
671 
672 /// Custom storage to ensure proper memory alignment for the allocation of
673 /// DenseArray of any element type.
674 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
675   using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
676                            ::llvm::ArrayRef<char>>;
677   DenseArrayBaseAttrStorage(ShapedType type,
678                             DenseArrayBaseAttr::EltType eltType,
679                             ::llvm::ArrayRef<char> elements)
680       : AttributeStorage(type), eltType(eltType), elements(elements) {}
681 
682   bool operator==(const KeyTy &tblgenKey) const {
683     return (getType() == std::get<0>(tblgenKey)) &&
684            (eltType == std::get<1>(tblgenKey)) &&
685            (elements == std::get<2>(tblgenKey));
686   }
687 
688   static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
689     return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
690                                 std::get<2>(tblgenKey));
691   }
692 
693   static DenseArrayBaseAttrStorage *
694   construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
695     auto type = std::get<0>(tblgenKey);
696     auto eltType = std::get<1>(tblgenKey);
697     auto elements = std::get<2>(tblgenKey);
698     if (!elements.empty()) {
699       char *alloc = static_cast<char *>(
700           allocator.allocate(elements.size(), alignof(uint64_t)));
701       std::uninitialized_copy(elements.begin(), elements.end(), alloc);
702       elements = ArrayRef<char>(alloc, elements.size());
703     }
704     return new (allocator.allocate<DenseArrayBaseAttrStorage>())
705         DenseArrayBaseAttrStorage(type, eltType, elements);
706   }
707 
708   DenseArrayBaseAttr::EltType eltType;
709   ::llvm::ArrayRef<char> elements;
710 };
711 
712 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
713   return getImpl()->eltType;
714 }
715 
716 const int8_t *
717 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
718   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
719 }
720 const int16_t *
721 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
722   return cast<DenseI16ArrayAttr>().asArrayRef().begin();
723 }
724 const int32_t *
725 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
726   return cast<DenseI32ArrayAttr>().asArrayRef().begin();
727 }
728 const int64_t *
729 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
730   return cast<DenseI64ArrayAttr>().asArrayRef().begin();
731 }
732 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
733   return cast<DenseF32ArrayAttr>().asArrayRef().begin();
734 }
735 const double *
736 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
737   return cast<DenseF64ArrayAttr>().asArrayRef().begin();
738 }
739 
740 void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
741   print(printer.getStream());
742 }
743 
744 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
745   switch (getElementType()) {
746   case DenseArrayBaseAttr::EltType::I8:
747     this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
748     return;
749   case DenseArrayBaseAttr::EltType::I16:
750     this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
751     return;
752   case DenseArrayBaseAttr::EltType::I32:
753     this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
754     return;
755   case DenseArrayBaseAttr::EltType::I64:
756     this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
757     return;
758   case DenseArrayBaseAttr::EltType::F32:
759     this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
760     return;
761   case DenseArrayBaseAttr::EltType::F64:
762     this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
763     return;
764   }
765   llvm_unreachable("<unknown DenseArrayBaseAttr>");
766 }
767 
768 void DenseArrayBaseAttr::print(raw_ostream &os) const {
769   os << "[";
770   printWithoutBraces(os);
771   os << "]";
772 }
773 
774 template <typename T>
775 void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
776   print(printer.getStream());
777 }
778 
779 template <typename T>
780 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
781   ArrayRef<T> values{*this};
782   llvm::interleaveComma(values, os);
783 }
784 
785 /// Specialization for int8_t for forcing printing as number instead of chars.
786 template <>
787 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
788   ArrayRef<int8_t> values{*this};
789   llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
790 }
791 
792 template <typename T>
793 void DenseArrayAttr<T>::print(raw_ostream &os) const {
794   os << "[";
795   printWithoutBraces(os);
796   os << "]";
797 }
798 
799 /// Parse a single element: generic template for int types, specialized for
800 /// floating points below.
801 template <typename T>
802 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
803   return parser.parseInteger(value);
804 }
805 
806 template <>
807 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
808   double doubleVal;
809   if (parser.parseFloat(doubleVal))
810     return failure();
811   value = doubleVal;
812   return success();
813 }
814 
815 template <>
816 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
817   return parser.parseFloat(value);
818 }
819 
820 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
821 template <typename T>
822 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
823                                                 Type odsType) {
824   SmallVector<T> data;
825   if (failed(parser.parseCommaSeparatedList([&]() {
826         T value;
827         if (parseDenseArrayAttrElt(parser, value))
828           return failure();
829         data.push_back(value);
830         return success();
831       })))
832     return {};
833   return get(parser.getContext(), data);
834 }
835 
836 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
837 template <typename T>
838 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
839   if (parser.parseLSquare())
840     return {};
841   // Handle empty list case.
842   if (succeeded(parser.parseOptionalRSquare()))
843     return get(parser.getContext(), {});
844   Attribute result = parseWithoutBraces(parser, odsType);
845   if (parser.parseRSquare())
846     return {};
847   return result;
848 }
849 
850 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
851 template <typename T>
852 DenseArrayAttr<T>::operator ArrayRef<T>() const {
853   ArrayRef<char> raw = getImpl()->elements;
854   assert((raw.size() % sizeof(T)) == 0);
855   return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
856                      raw.size() / sizeof(T));
857 }
858 
859 namespace {
860 /// Mapping from C++ element type to MLIR DenseArrayAttr internals.
861 template <typename T>
862 struct denseArrayAttrEltTypeBuilder;
863 template <>
864 struct denseArrayAttrEltTypeBuilder<int8_t> {
865   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
866   static ShapedType getShapedType(MLIRContext *context,
867                                   ArrayRef<int64_t> shape) {
868     return VectorType::get(shape, IntegerType::get(context, 8));
869   }
870 };
871 template <>
872 struct denseArrayAttrEltTypeBuilder<int16_t> {
873   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
874   static ShapedType getShapedType(MLIRContext *context,
875                                   ArrayRef<int64_t> shape) {
876     return VectorType::get(shape, IntegerType::get(context, 16));
877   }
878 };
879 template <>
880 struct denseArrayAttrEltTypeBuilder<int32_t> {
881   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
882   static ShapedType getShapedType(MLIRContext *context,
883                                   ArrayRef<int64_t> shape) {
884     return VectorType::get(shape, IntegerType::get(context, 32));
885   }
886 };
887 template <>
888 struct denseArrayAttrEltTypeBuilder<int64_t> {
889   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
890   static ShapedType getShapedType(MLIRContext *context,
891                                   ArrayRef<int64_t> shape) {
892     return VectorType::get(shape, IntegerType::get(context, 64));
893   }
894 };
895 template <>
896 struct denseArrayAttrEltTypeBuilder<float> {
897   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
898   static ShapedType getShapedType(MLIRContext *context,
899                                   ArrayRef<int64_t> shape) {
900     return VectorType::get(shape, Float32Type::get(context));
901   }
902 };
903 template <>
904 struct denseArrayAttrEltTypeBuilder<double> {
905   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
906   static ShapedType getShapedType(MLIRContext *context,
907                                   ArrayRef<int64_t> shape) {
908     return VectorType::get(shape, Float64Type::get(context));
909   }
910 };
911 } // namespace
912 
913 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
914 template <typename T>
915 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
916                                          ArrayRef<T> content) {
917   auto size = static_cast<int64_t>(content.size());
918   auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
919       context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
920   auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
921   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
922                                  content.size() * sizeof(T));
923   return Base::get(context, shapedType, eltType, rawArray)
924       .template cast<DenseArrayAttr<T>>();
925 }
926 
927 template <typename T>
928 bool DenseArrayAttr<T>::classof(Attribute attr) {
929   return attr.isa<DenseArrayBaseAttr>() &&
930          attr.cast<DenseArrayBaseAttr>().getElementType() ==
931              denseArrayAttrEltTypeBuilder<T>::eltType;
932 }
933 
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
// The member templates of DenseArrayAttr<T> are defined in this file, so each
// element type with a denseArrayAttrEltTypeBuilder specialization is
// instantiated here. NOTE(review): presumably the header only declares these
// members, which would make this list the set of linkable element types.
template class DenseArrayAttr<int8_t>;
template class DenseArrayAttr<int16_t>;
template class DenseArrayAttr<int32_t>;
template class DenseArrayAttr<int64_t>;
template class DenseArrayAttr<float>;
template class DenseArrayAttr<double>;
} // namespace detail
} // namespace mlir
945 
946 //===----------------------------------------------------------------------===//
947 // DenseElementsAttr
948 //===----------------------------------------------------------------------===//
949 
950 /// Method for support type inquiry through isa, cast and dyn_cast.
951 bool DenseElementsAttr::classof(Attribute attr) {
952   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
953 }
954 
955 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
956                                          ArrayRef<Attribute> values) {
957   assert(hasSameElementsOrSplat(type, values));
958 
959   // If the element type is not based on int/float/index, assume it is a string
960   // type.
961   auto eltType = type.getElementType();
962   if (!type.getElementType().isIntOrIndexOrFloat()) {
963     SmallVector<StringRef, 8> stringValues;
964     stringValues.reserve(values.size());
965     for (Attribute attr : values) {
966       assert(attr.isa<StringAttr>() &&
967              "expected string value for non integer/index/float element");
968       stringValues.push_back(attr.cast<StringAttr>().getValue());
969     }
970     return get(type, stringValues);
971   }
972 
973   // Otherwise, get the raw storage width to use for the allocation.
974   size_t bitWidth = getDenseElementBitWidth(eltType);
975   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
976 
977   // Compress the attribute values into a character buffer.
978   SmallVector<char, 8> data(
979       llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
980   APInt intVal;
981   for (unsigned i = 0, e = values.size(); i < e; ++i) {
982     assert(eltType == values[i].getType() &&
983            "expected attribute value to have element type");
984     if (eltType.isa<FloatType>())
985       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
986     else if (eltType.isa<IntegerType, IndexType>())
987       intVal = values[i].cast<IntegerAttr>().getValue();
988     else
989       llvm_unreachable("unexpected element type");
990 
991     assert(intVal.getBitWidth() == bitWidth &&
992            "expected value to have same bitwidth as element type");
993     writeBits(data.data(), i * storageBitWidth, intVal);
994   }
995 
996   // Handle the special encoding of splat of bool.
997   if (values.size() == 1 && values[0].getType().isInteger(1))
998     data[0] = data[0] ? -1 : 0;
999 
1000   return DenseIntOrFPElementsAttr::getRaw(type, data);
1001 }
1002 
1003 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1004                                          ArrayRef<bool> values) {
1005   assert(hasSameElementsOrSplat(type, values));
1006   assert(type.getElementType().isInteger(1));
1007 
1008   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
1009 
1010   if (!values.empty()) {
1011     bool isSplat = true;
1012     bool firstValue = values[0];
1013     for (int i = 0, e = values.size(); i != e; ++i) {
1014       isSplat &= values[i] == firstValue;
1015       setBit(buff.data(), i, values[i]);
1016     }
1017 
1018     // Splat of bool is encoded as a byte with all-ones in it.
1019     if (isSplat) {
1020       buff.resize(1);
1021       buff[0] = values[0] ? -1 : 0;
1022     }
1023   }
1024 
1025   return DenseIntOrFPElementsAttr::getRaw(type, buff);
1026 }
1027 
1028 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1029                                          ArrayRef<StringRef> values) {
1030   assert(!type.getElementType().isIntOrFloat());
1031   return DenseStringElementsAttr::get(type, values);
1032 }
1033 
1034 /// Constructs a dense integer elements attribute from an array of APInt
1035 /// values. Each APInt value is expected to have the same bitwidth as the
1036 /// element type of 'type'.
1037 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1038                                          ArrayRef<APInt> values) {
1039   assert(type.getElementType().isIntOrIndex());
1040   assert(hasSameElementsOrSplat(type, values));
1041   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1042   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1043 }
1044 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1045                                          ArrayRef<std::complex<APInt>> values) {
1046   ComplexType complex = type.getElementType().cast<ComplexType>();
1047   assert(complex.getElementType().isa<IntegerType>());
1048   assert(hasSameElementsOrSplat(type, values));
1049   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1050   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1051                           values.size() * 2);
1052   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1053 }
1054 
1055 // Constructs a dense float elements attribute from an array of APFloat
1056 // values. Each APFloat value is expected to have the same bitwidth as the
1057 // element type of 'type'.
1058 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1059                                          ArrayRef<APFloat> values) {
1060   assert(type.getElementType().isa<FloatType>());
1061   assert(hasSameElementsOrSplat(type, values));
1062   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1063   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1064 }
1065 DenseElementsAttr
1066 DenseElementsAttr::get(ShapedType type,
1067                        ArrayRef<std::complex<APFloat>> values) {
1068   ComplexType complex = type.getElementType().cast<ComplexType>();
1069   assert(complex.getElementType().isa<FloatType>());
1070   assert(hasSameElementsOrSplat(type, values));
1071   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1072                            values.size() * 2);
1073   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1074   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1075 }
1076 
1077 /// Construct a dense elements attribute from a raw buffer representing the
1078 /// data for this attribute. Users should generally not use this methods as
1079 /// the expected buffer format may not be a form the user expects.
1080 DenseElementsAttr
1081 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1082   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1083 }
1084 
1085 /// Returns true if the given buffer is a valid raw buffer for the given type.
1086 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1087                                          ArrayRef<char> rawBuffer,
1088                                          bool &detectedSplat) {
1089   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1090   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1091   int64_t numElements = type.getNumElements();
1092 
1093   // The initializer is always a splat if the result type has a single element.
1094   detectedSplat = numElements == 1;
1095 
1096   // Storage width of 1 is special as it is packed by the bit.
1097   if (storageWidth == 1) {
1098     // Check for a splat, or a buffer equal to the number of elements which
1099     // consists of either all 0's or all 1's.
1100     if (rawBuffer.size() == 1) {
1101       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1102       if (rawByte == 0 || rawByte == 0xff) {
1103         detectedSplat = true;
1104         return true;
1105       }
1106     }
1107 
1108     // This is a valid non-splat buffer if it has the right size.
1109     return rawBufferWidth == llvm::alignTo<8>(numElements);
1110   }
1111 
1112   // All other types are 8-bit aligned, so we can just check the buffer width
1113   // to know if only a single initializer element was passed in.
1114   if (rawBufferWidth == storageWidth) {
1115     detectedSplat = true;
1116     return true;
1117   }
1118 
1119   // The raw buffer is valid if it has the right size.
1120   return rawBufferWidth == storageWidth * numElements;
1121 }
1122 
1123 /// Check the information for a C++ data type, check if this type is valid for
1124 /// the current attribute. This method is used to verify specific type
1125 /// invariants that the templatized 'getValues' method cannot.
1126 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1127                               bool isSigned) {
1128   // Make sure that the data element size is the same as the type element width.
1129   if (getDenseElementBitWidth(type) !=
1130       static_cast<size_t>(dataEltSize * CHAR_BIT))
1131     return false;
1132 
1133   // Check that the element type is either float or integer or index.
1134   if (!isInt)
1135     return type.isa<FloatType>();
1136   if (type.isIndex())
1137     return true;
1138 
1139   auto intType = type.dyn_cast<IntegerType>();
1140   if (!intType)
1141     return false;
1142 
1143   // Make sure signedness semantics is consistent.
1144   if (intType.isSignless())
1145     return true;
1146   return intType.isSigned() ? isSigned : !isSigned;
1147 }
1148 
1149 /// Defaults down the subclass implementation.
1150 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1151                                                    ArrayRef<char> data,
1152                                                    int64_t dataEltSize,
1153                                                    bool isInt, bool isSigned) {
1154   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1155                                                  isSigned);
1156 }
1157 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1158                                                       ArrayRef<char> data,
1159                                                       int64_t dataEltSize,
1160                                                       bool isInt,
1161                                                       bool isSigned) {
1162   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1163                                                     isInt, isSigned);
1164 }
1165 
1166 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1167                                           bool isSigned) const {
1168   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1169 }
1170 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1171                                        bool isSigned) const {
1172   return ::isValidIntOrFloat(
1173       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1174       isInt, isSigned);
1175 }
1176 
1177 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1178 /// values are the same.
1179 bool DenseElementsAttr::isSplat() const {
1180   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1181 }
1182 
1183 /// Return if the given complex type has an integer element type.
1184 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
1185   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1186 }
1187 
1188 auto DenseElementsAttr::getComplexIntValues() const
1189     -> iterator_range_impl<ComplexIntElementIterator> {
1190   assert(isComplexOfIntType(getElementType()) &&
1191          "expected complex integral type");
1192   return {getType(), ComplexIntElementIterator(*this, 0),
1193           ComplexIntElementIterator(*this, getNumElements())};
1194 }
1195 auto DenseElementsAttr::complex_value_begin() const
1196     -> ComplexIntElementIterator {
1197   assert(isComplexOfIntType(getElementType()) &&
1198          "expected complex integral type");
1199   return ComplexIntElementIterator(*this, 0);
1200 }
1201 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
1202   assert(isComplexOfIntType(getElementType()) &&
1203          "expected complex integral type");
1204   return ComplexIntElementIterator(*this, getNumElements());
1205 }
1206 
1207 /// Return the held element values as a range of APFloat. The element type of
1208 /// this attribute must be of float type.
1209 auto DenseElementsAttr::getFloatValues() const
1210     -> iterator_range_impl<FloatElementIterator> {
1211   auto elementType = getElementType().cast<FloatType>();
1212   const auto &elementSemantics = elementType.getFloatSemantics();
1213   return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1214           FloatElementIterator(elementSemantics, raw_int_end())};
1215 }
1216 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1217   auto elementType = getElementType().cast<FloatType>();
1218   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
1219 }
1220 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1221   auto elementType = getElementType().cast<FloatType>();
1222   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
1223 }
1224 
1225 auto DenseElementsAttr::getComplexFloatValues() const
1226     -> iterator_range_impl<ComplexFloatElementIterator> {
1227   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1228   assert(eltTy.isa<FloatType>() && "expected complex float type");
1229   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1230   return {getType(),
1231           {semantics, {*this, 0}},
1232           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1233 }
1234 auto DenseElementsAttr::complex_float_value_begin() const
1235     -> ComplexFloatElementIterator {
1236   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1237   assert(eltTy.isa<FloatType>() && "expected complex float type");
1238   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
1239 }
1240 auto DenseElementsAttr::complex_float_value_end() const
1241     -> ComplexFloatElementIterator {
1242   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1243   assert(eltTy.isa<FloatType>() && "expected complex float type");
1244   return {eltTy.cast<FloatType>().getFloatSemantics(),
1245           {*this, static_cast<size_t>(getNumElements())}};
1246 }
1247 
1248 /// Return the raw storage data held by this attribute.
1249 ArrayRef<char> DenseElementsAttr::getRawData() const {
1250   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1251 }
1252 
1253 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1254   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1255 }
1256 
1257 /// Return a new DenseElementsAttr that has the same data as the current
1258 /// attribute, but has been reshaped to 'newType'. The new type must have the
1259 /// same total number of elements as well as element type.
1260 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1261   ShapedType curType = getType();
1262   if (curType == newType)
1263     return *this;
1264 
1265   assert(newType.getElementType() == curType.getElementType() &&
1266          "expected the same element type");
1267   assert(newType.getNumElements() == curType.getNumElements() &&
1268          "expected the same number of elements");
1269   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1270 }
1271 
1272 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1273   assert(isSplat() && "expected a splat type");
1274 
1275   ShapedType curType = getType();
1276   if (curType == newType)
1277     return *this;
1278 
1279   assert(newType.getElementType() == curType.getElementType() &&
1280          "expected the same element type");
1281   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1282 }
1283 
1284 /// Return a new DenseElementsAttr that has the same data as the current
1285 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1286 /// type must have the same shape and element types of the same bitwidth as the
1287 /// current type.
1288 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1289   ShapedType curType = getType();
1290   Type curElType = curType.getElementType();
1291   if (curElType == newElType)
1292     return *this;
1293 
1294   assert(getDenseElementBitWidth(newElType) ==
1295              getDenseElementBitWidth(curElType) &&
1296          "expected element types with the same bitwidth");
1297   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1298                                           getRawData());
1299 }
1300 
1301 DenseElementsAttr
1302 DenseElementsAttr::mapValues(Type newElementType,
1303                              function_ref<APInt(const APInt &)> mapping) const {
1304   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1305 }
1306 
1307 DenseElementsAttr DenseElementsAttr::mapValues(
1308     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1309   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1310 }
1311 
1312 ShapedType DenseElementsAttr::getType() const {
1313   return Attribute::getType().cast<ShapedType>();
1314 }
1315 
1316 Type DenseElementsAttr::getElementType() const {
1317   return getType().getElementType();
1318 }
1319 
1320 int64_t DenseElementsAttr::getNumElements() const {
1321   return getType().getNumElements();
1322 }
1323 
1324 //===----------------------------------------------------------------------===//
1325 // DenseIntOrFPElementsAttr
1326 //===----------------------------------------------------------------------===//
1327 
1328 /// Utility method to write a range of APInt values to a buffer.
1329 template <typename APRangeT>
1330 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1331                                 APRangeT &&values) {
1332   size_t numValues = llvm::size(values);
1333   data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1334   size_t offset = 0;
1335   for (auto it = values.begin(), e = values.end(); it != e;
1336        ++it, offset += storageWidth) {
1337     assert((*it).getBitWidth() <= storageWidth);
1338     writeBits(data.data(), offset, *it);
1339   }
1340 
1341   // Handle the special encoding of splat of a boolean.
1342   if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1343     data[0] = data[0] ? -1 : 0;
1344 }
1345 
1346 /// Constructs a dense elements attribute from an array of raw APFloat values.
1347 /// Each APFloat value is expected to have the same bitwidth as the element
1348 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1349 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1350                                                    size_t storageWidth,
1351                                                    ArrayRef<APFloat> values) {
1352   std::vector<char> data;
1353   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1354   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1355   return DenseIntOrFPElementsAttr::getRaw(type, data);
1356 }
1357 
1358 /// Constructs a dense elements attribute from an array of raw APInt values.
1359 /// Each APInt value is expected to have the same bitwidth as the element type
1360 /// of 'type'.
1361 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1362                                                    size_t storageWidth,
1363                                                    ArrayRef<APInt> values) {
1364   std::vector<char> data;
1365   writeAPIntsToBuffer(storageWidth, data, values);
1366   return DenseIntOrFPElementsAttr::getRaw(type, data);
1367 }
1368 
1369 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1370                                                    ArrayRef<char> data) {
1371   assert((type.isa<RankedTensorType, VectorType>()) &&
1372          "type must be ranked tensor or vector");
1373   assert(type.hasStaticShape() && "type must have static shape");
1374   bool isSplat = false;
1375   bool isValid = isValidRawBuffer(type, data, isSplat);
1376   assert(isValid);
1377   (void)isValid;
1378   return Base::get(type.getContext(), type, data, isSplat);
1379 }
1380 
1381 /// Overload of the raw 'get' method that asserts that the given type is of
1382 /// complex type. This method is used to verify type invariants that the
1383 /// templatized 'get' method cannot.
1384 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1385                                                           ArrayRef<char> data,
1386                                                           int64_t dataEltSize,
1387                                                           bool isInt,
1388                                                           bool isSigned) {
1389   assert(::isValidIntOrFloat(
1390       type.getElementType().cast<ComplexType>().getElementType(),
1391       dataEltSize / 2, isInt, isSigned));
1392 
1393   int64_t numElements = data.size() / dataEltSize;
1394   (void)numElements;
1395   assert(numElements == 1 || numElements == type.getNumElements());
1396   return getRaw(type, data);
1397 }
1398 
1399 /// Overload of the 'getRaw' method that asserts that the given type is of
1400 /// integer type. This method is used to verify type invariants that the
1401 /// templatized 'get' method cannot.
1402 DenseElementsAttr
1403 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1404                                            int64_t dataEltSize, bool isInt,
1405                                            bool isSigned) {
1406   assert(
1407       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1408 
1409   int64_t numElements = data.size() / dataEltSize;
1410   assert(numElements == 1 || numElements == type.getNumElements());
1411   (void)numElements;
1412   return getRaw(type, data);
1413 }
1414 
1415 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1416     const char *inRawData, char *outRawData, size_t elementBitWidth,
1417     size_t numElements) {
1418   using llvm::support::ulittle16_t;
1419   using llvm::support::ulittle32_t;
1420   using llvm::support::ulittle64_t;
1421 
1422   assert(llvm::support::endian::system_endianness() == // NOLINT
1423          llvm::support::endianness::big);              // NOLINT
1424   // NOLINT to avoid warning message about replacing by static_assert()
1425 
1426   // Following std::copy_n always converts endianness on BE machine.
1427   switch (elementBitWidth) {
1428   case 16: {
1429     const ulittle16_t *inRawDataPos =
1430         reinterpret_cast<const ulittle16_t *>(inRawData);
1431     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1432     std::copy_n(inRawDataPos, numElements, outDataPos);
1433     break;
1434   }
1435   case 32: {
1436     const ulittle32_t *inRawDataPos =
1437         reinterpret_cast<const ulittle32_t *>(inRawData);
1438     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1439     std::copy_n(inRawDataPos, numElements, outDataPos);
1440     break;
1441   }
1442   case 64: {
1443     const ulittle64_t *inRawDataPos =
1444         reinterpret_cast<const ulittle64_t *>(inRawData);
1445     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1446     std::copy_n(inRawDataPos, numElements, outDataPos);
1447     break;
1448   }
1449   default: {
1450     size_t nBytes = elementBitWidth / CHAR_BIT;
1451     for (size_t i = 0; i < nBytes; i++)
1452       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1453     break;
1454   }
1455   }
1456 }
1457 
1458 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1459     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1460     ShapedType type) {
1461   size_t numElements = type.getNumElements();
1462   Type elementType = type.getElementType();
1463   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1464     elementType = complexTy.getElementType();
1465     numElements = numElements * 2;
1466   }
1467   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1468   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1469          inRawData.size() <= outRawData.size());
1470   if (elementBitWidth <= CHAR_BIT)
1471     std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1472   else
1473     convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1474                                     elementBitWidth, numElements);
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // DenseFPElementsAttr
1479 //===----------------------------------------------------------------------===//
1480 
1481 template <typename Fn, typename Attr>
1482 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1483                                 Type newElementType,
1484                                 llvm::SmallVectorImpl<char> &data) {
1485   size_t bitWidth = getDenseElementBitWidth(newElementType);
1486   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1487 
1488   ShapedType newArrayType;
1489   if (inType.isa<RankedTensorType>())
1490     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1491   else if (inType.isa<UnrankedTensorType>())
1492     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1493   else if (auto vType = inType.dyn_cast<VectorType>())
1494     newArrayType = VectorType::get(vType.getShape(), newElementType,
1495                                    vType.getNumScalableDims());
1496   else
1497     assert(newArrayType && "Unhandled tensor type");
1498 
1499   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1500   data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1501 
1502   // Functor used to process a single element value of the attribute.
1503   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1504     auto newInt = mapping(value);
1505     assert(newInt.getBitWidth() == bitWidth);
1506     writeBits(data.data(), index * storageBitWidth, newInt);
1507   };
1508 
1509   // Check for the splat case.
1510   if (attr.isSplat()) {
1511     processElt(*attr.begin(), /*index=*/0);
1512     return newArrayType;
1513   }
1514 
1515   // Otherwise, process all of the element values.
1516   uint64_t elementIdx = 0;
1517   for (auto value : attr)
1518     processElt(value, elementIdx++);
1519   return newArrayType;
1520 }
1521 
1522 DenseElementsAttr DenseFPElementsAttr::mapValues(
1523     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1524   llvm::SmallVector<char, 8> elementData;
1525   auto newArrayType =
1526       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1527 
1528   return getRaw(newArrayType, elementData);
1529 }
1530 
1531 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1532 bool DenseFPElementsAttr::classof(Attribute attr) {
1533   return attr.isa<DenseElementsAttr>() &&
1534          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // DenseIntElementsAttr
1539 //===----------------------------------------------------------------------===//
1540 
1541 DenseElementsAttr DenseIntElementsAttr::mapValues(
1542     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1543   llvm::SmallVector<char, 8> elementData;
1544   auto newArrayType =
1545       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1546   return getRaw(newArrayType, elementData);
1547 }
1548 
1549 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1550 bool DenseIntElementsAttr::classof(Attribute attr) {
1551   return attr.isa<DenseElementsAttr>() &&
1552          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // OpaqueElementsAttr
1557 //===----------------------------------------------------------------------===//
1558 
1559 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1560   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1561   if (!dialect)
1562     return true;
1563   auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
1564   if (!interface)
1565     return true;
1566   return failed(interface->decode(*this, result));
1567 }
1568 
1569 LogicalResult
1570 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1571                            StringAttr dialect, StringRef value,
1572                            ShapedType type) {
1573   if (!Dialect::isValidNamespace(dialect.strref()))
1574     return emitError() << "invalid dialect namespace '" << dialect << "'";
1575   return success();
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // SparseElementsAttr
1580 //===----------------------------------------------------------------------===//
1581 
1582 /// Get a zero APFloat for the given sparse attribute.
1583 APFloat SparseElementsAttr::getZeroAPFloat() const {
1584   auto eltType = getElementType().cast<FloatType>();
1585   return APFloat(eltType.getFloatSemantics());
1586 }
1587 
1588 /// Get a zero APInt for the given sparse attribute.
1589 APInt SparseElementsAttr::getZeroAPInt() const {
1590   auto eltType = getElementType().cast<IntegerType>();
1591   return APInt::getZero(eltType.getWidth());
1592 }
1593 
1594 /// Get a zero attribute for the given attribute type.
1595 Attribute SparseElementsAttr::getZeroAttr() const {
1596   auto eltType = getElementType();
1597 
1598   // Handle floating point elements.
1599   if (eltType.isa<FloatType>())
1600     return FloatAttr::get(eltType, 0);
1601 
1602   // Handle complex elements.
1603   if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1604     auto eltType = complexTy.getElementType();
1605     Attribute zero;
1606     if (eltType.isa<FloatType>())
1607       zero = FloatAttr::get(eltType, 0);
1608     else // must be integer
1609       zero = IntegerAttr::get(eltType, 0);
1610     return ArrayAttr::get(complexTy.getContext(),
1611                           ArrayRef<Attribute>{zero, zero});
1612   }
1613 
1614   // Handle string type.
1615   if (getValues().isa<DenseStringElementsAttr>())
1616     return StringAttr::get("", eltType);
1617 
1618   // Otherwise, this is an integer.
1619   return IntegerAttr::get(eltType, 0);
1620 }
1621 
1622 /// Flatten, and return, all of the sparse indices in this attribute in
1623 /// row-major order.
1624 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1625   std::vector<ptrdiff_t> flatSparseIndices;
1626 
1627   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1628   // as a 1-D index array.
1629   auto sparseIndices = getIndices();
1630   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1631   if (sparseIndices.isSplat()) {
1632     SmallVector<uint64_t, 8> indices(getType().getRank(),
1633                                      *sparseIndexValues.begin());
1634     flatSparseIndices.push_back(getFlattenedIndex(indices));
1635     return flatSparseIndices;
1636   }
1637 
1638   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1639   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1640   size_t rank = getType().getRank();
1641   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1642     flatSparseIndices.push_back(getFlattenedIndex(
1643         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1644   return flatSparseIndices;
1645 }
1646 
1647 LogicalResult
1648 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1649                            ShapedType type, DenseIntElementsAttr sparseIndices,
1650                            DenseElementsAttr values) {
1651   ShapedType valuesType = values.getType();
1652   if (valuesType.getRank() != 1)
1653     return emitError() << "expected 1-d tensor for sparse element values";
1654 
1655   // Verify the indices and values shape.
1656   ShapedType indicesType = sparseIndices.getType();
1657   auto emitShapeError = [&]() {
1658     return emitError() << "expected shape ([" << type.getShape()
1659                        << "]); inferred shape of indices literal (["
1660                        << indicesType.getShape()
1661                        << "]); inferred shape of values literal (["
1662                        << valuesType.getShape() << "])";
1663   };
1664   // Verify indices shape.
1665   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1666   if (indicesRank == 2) {
1667     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1668       return emitShapeError();
1669   } else if (indicesRank != 1 || rank != 1) {
1670     return emitShapeError();
1671   }
1672   // Verify the values shape.
1673   int64_t numSparseIndices = indicesType.getDimSize(0);
1674   if (numSparseIndices != valuesType.getDimSize(0))
1675     return emitShapeError();
1676 
1677   // Verify that the sparse indices are within the value shape.
1678   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1679     return emitError()
1680            << "sparse index #" << indexNum
1681            << " is not contained within the value shape, with index=[" << index
1682            << "], and type=" << type;
1683   };
1684 
1685   // Handle the case where the index values are a splat.
1686   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1687   if (sparseIndices.isSplat()) {
1688     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1689     if (!ElementsAttr::isValidIndex(type, indices))
1690       return emitIndexError(0, indices);
1691     return success();
1692   }
1693 
1694   // Otherwise, reinterpret each index as an ArrayRef.
1695   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1696     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1697                              rank);
1698     if (!ElementsAttr::isValidIndex(type, index))
1699       return emitIndexError(i, index);
1700   }
1701 
1702   return success();
1703 }
1704 
1705 //===----------------------------------------------------------------------===//
1706 // TypeAttr
1707 //===----------------------------------------------------------------------===//
1708 
1709 void TypeAttr::walkImmediateSubElements(
1710     function_ref<void(Attribute)> walkAttrsFn,
1711     function_ref<void(Type)> walkTypesFn) const {
1712   walkTypesFn(getValue());
1713 }
1714