1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 //===----------------------------------------------------------------------===//
26 /// Tablegen Attribute Definitions
27 //===----------------------------------------------------------------------===//
28 
29 #define GET_ATTRDEF_CLASSES
30 #include "mlir/IR/BuiltinAttributes.cpp.inc"
31 
32 //===----------------------------------------------------------------------===//
33 // BuiltinDialect
34 //===----------------------------------------------------------------------===//
35 
36 void BuiltinDialect::registerAttributes() {
37   addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
38                 DenseStringElementsAttr, DictionaryAttr, FloatAttr,
39                 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
40                 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
41                 UnitAttr>();
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // DictionaryAttr
46 //===----------------------------------------------------------------------===//
47 
48 /// Helper function that does either an in place sort or sorts from source array
49 /// into destination. If inPlace then storage is both the source and the
50 /// destination, else value is the source and storage destination. Returns
51 /// whether source was sorted.
52 template <bool inPlace>
53 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
54                                SmallVectorImpl<NamedAttribute> &storage) {
55   // Specialize for the common case.
56   switch (value.size()) {
57   case 0:
58     // Zero already sorted.
59     break;
60   case 1:
61     // One already sorted but may need to be copied.
62     if (!inPlace)
63       storage.assign({value[0]});
64     break;
65   case 2: {
66     bool isSorted = value[0] < value[1];
67     if (inPlace) {
68       if (!isSorted)
69         std::swap(storage[0], storage[1]);
70     } else if (isSorted) {
71       storage.assign({value[0], value[1]});
72     } else {
73       storage.assign({value[1], value[0]});
74     }
75     return !isSorted;
76   }
77   default:
78     if (!inPlace)
79       storage.assign(value.begin(), value.end());
80     // Check to see they are sorted already.
81     bool isSorted = llvm::is_sorted(value);
82     // If not, do a general sort.
83     if (!isSorted)
84       llvm::array_pod_sort(storage.begin(), storage.end());
85     return !isSorted;
86   }
87   return false;
88 }
89 
90 /// Returns an entry with a duplicate name from the given sorted array of named
91 /// attributes. Returns llvm::None if all elements have unique names.
92 static Optional<NamedAttribute>
93 findDuplicateElement(ArrayRef<NamedAttribute> value) {
94   const Optional<NamedAttribute> none{llvm::None};
95   if (value.size() < 2)
96     return none;
97 
98   if (value.size() == 2)
99     return value[0].first == value[1].first ? value[0] : none;
100 
101   auto it = std::adjacent_find(
102       value.begin(), value.end(),
103       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
104   return it != value.end() ? *it : none;
105 }
106 
107 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
108                           SmallVectorImpl<NamedAttribute> &storage) {
109   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
110   assert(!findDuplicateElement(storage) &&
111          "DictionaryAttr element names must be unique");
112   return isSorted;
113 }
114 
115 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
116   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
117   assert(!findDuplicateElement(array) &&
118          "DictionaryAttr element names must be unique");
119   return isSorted;
120 }
121 
122 Optional<NamedAttribute>
123 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
124                               bool isSorted) {
125   if (!isSorted)
126     dictionaryAttrSort</*inPlace=*/true>(array, array);
127   return findDuplicateElement(array);
128 }
129 
130 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
131                                    ArrayRef<NamedAttribute> value) {
132   if (value.empty())
133     return DictionaryAttr::getEmpty(context);
134   assert(llvm::all_of(value,
135                       [](const NamedAttribute &attr) { return attr.second; }) &&
136          "value cannot have null entries");
137 
138   // We need to sort the element list to canonicalize it.
139   SmallVector<NamedAttribute, 8> storage;
140   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
141     value = storage;
142   assert(!findDuplicateElement(value) &&
143          "DictionaryAttr element names must be unique");
144   return Base::get(context, value);
145 }
146 /// Construct a dictionary with an array of values that is known to already be
147 /// sorted by name and uniqued.
148 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
149                                              ArrayRef<NamedAttribute> value) {
150   if (value.empty())
151     return DictionaryAttr::getEmpty(context);
152   // Ensure that the attribute elements are unique and sorted.
153   assert(llvm::is_sorted(value,
154                          [](NamedAttribute l, NamedAttribute r) {
155                            return l.first.strref() < r.first.strref();
156                          }) &&
157          "expected attribute values to be sorted");
158   assert(!findDuplicateElement(value) &&
159          "DictionaryAttr element names must be unique");
160   return Base::get(context, value);
161 }
162 
163 /// Return the specified attribute if present, null otherwise.
164 Attribute DictionaryAttr::get(StringRef name) const {
165   Optional<NamedAttribute> attr = getNamed(name);
166   return attr ? attr->second : nullptr;
167 }
168 Attribute DictionaryAttr::get(Identifier name) const {
169   Optional<NamedAttribute> attr = getNamed(name);
170   return attr ? attr->second : nullptr;
171 }
172 
173 /// Return the specified named attribute if present, None otherwise.
174 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
175   ArrayRef<NamedAttribute> values = getValue();
176   const auto *it = llvm::lower_bound(values, name);
177   return it != values.end() && it->first == name ? *it
178                                                  : Optional<NamedAttribute>();
179 }
180 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
181   for (auto elt : getValue())
182     if (elt.first == name)
183       return elt;
184   return llvm::None;
185 }
186 
187 DictionaryAttr::iterator DictionaryAttr::begin() const {
188   return getValue().begin();
189 }
190 DictionaryAttr::iterator DictionaryAttr::end() const {
191   return getValue().end();
192 }
193 size_t DictionaryAttr::size() const { return getValue().size(); }
194 
195 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
196   return Base::get(context, ArrayRef<NamedAttribute>());
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // FloatAttr
201 //===----------------------------------------------------------------------===//
202 
203 double FloatAttr::getValueAsDouble() const {
204   return getValueAsDouble(getValue());
205 }
206 double FloatAttr::getValueAsDouble(APFloat value) {
207   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
208     bool losesInfo = false;
209     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
210                   &losesInfo);
211   }
212   return value.convertToDouble();
213 }
214 
215 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
216                                 Type type, APFloat value) {
217   // Verify that the type is correct.
218   if (!type.isa<FloatType>())
219     return emitError() << "expected floating point type";
220 
221   // Verify that the type semantics match that of the value.
222   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
223     return emitError()
224            << "FloatAttr type doesn't match the type implied by its value";
225   }
226   return success();
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // SymbolRefAttr
231 //===----------------------------------------------------------------------===//
232 
233 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
234   return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
235 }
236 
237 StringRef SymbolRefAttr::getLeafReference() const {
238   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
239   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // IntegerAttr
244 //===----------------------------------------------------------------------===//
245 
246 int64_t IntegerAttr::getInt() const {
247   assert((getType().isIndex() || getType().isSignlessInteger()) &&
248          "must be signless integer");
249   return getValue().getSExtValue();
250 }
251 
252 int64_t IntegerAttr::getSInt() const {
253   assert(getType().isSignedInteger() && "must be signed integer");
254   return getValue().getSExtValue();
255 }
256 
257 uint64_t IntegerAttr::getUInt() const {
258   assert(getType().isUnsignedInteger() && "must be unsigned integer");
259   return getValue().getZExtValue();
260 }
261 
262 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
263                                   Type type, APInt value) {
264   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
265     if (integerType.getWidth() != value.getBitWidth())
266       return emitError() << "integer type bit width (" << integerType.getWidth()
267                          << ") doesn't match value bit width ("
268                          << value.getBitWidth() << ")";
269     return success();
270   }
271   if (type.isa<IndexType>())
272     return success();
273   return emitError() << "expected integer or index type";
274 }
275 
276 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
277   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
278   return attr.cast<BoolAttr>();
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // BoolAttr
283 
284 bool BoolAttr::getValue() const {
285   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
286   return storage->value.getBoolValue();
287 }
288 
289 bool BoolAttr::classof(Attribute attr) {
290   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
291   return intAttr && intAttr.getType().isSignlessInteger(1);
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // OpaqueAttr
296 //===----------------------------------------------------------------------===//
297 
298 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
299                                  Identifier dialect, StringRef attrData,
300                                  Type type) {
301   if (!Dialect::isValidNamespace(dialect.strref()))
302     return emitError() << "invalid dialect namespace '" << dialect << "'";
303 
304   // Check that the dialect is actually registered.
305   MLIRContext *context = dialect.getContext();
306   if (!context->allowsUnregisteredDialects() &&
307       !context->getLoadedDialect(dialect.strref())) {
308     return emitError()
309            << "#" << dialect << "<\"" << attrData << "\"> : " << type
310            << " attribute created with unregistered dialect. If this is "
311               "intended, please call allowUnregisteredDialects() on the "
312               "MLIRContext, or use -allow-unregistered-dialect with "
313               "mlir-opt";
314   }
315 
316   return success();
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // ElementsAttr
321 //===----------------------------------------------------------------------===//
322 
323 ShapedType ElementsAttr::getType() const {
324   return Attribute::getType().cast<ShapedType>();
325 }
326 
327 /// Returns the number of elements held by this attribute.
328 int64_t ElementsAttr::getNumElements() const {
329   return getType().getNumElements();
330 }
331 
332 /// Return the value at the given index. If index does not refer to a valid
333 /// element, then a null attribute is returned.
334 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
335   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
336     return denseAttr.getValue(index);
337   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
338     return opaqueAttr.getValue(index);
339   return cast<SparseElementsAttr>().getValue(index);
340 }
341 
342 /// Return if the given 'index' refers to a valid element in this attribute.
343 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
344   auto type = getType();
345 
346   // Verify that the rank of the indices matches the held type.
347   auto rank = type.getRank();
348   if (rank == 0 && index.size() == 1 && index[0] == 0)
349     return true;
350   if (rank != static_cast<int64_t>(index.size()))
351     return false;
352 
353   // Verify that all of the indices are within the shape dimensions.
354   auto shape = type.getShape();
355   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
356     int64_t dim = static_cast<int64_t>(index[i]);
357     return 0 <= dim && dim < shape[i];
358   });
359 }
360 
361 ElementsAttr
362 ElementsAttr::mapValues(Type newElementType,
363                         function_ref<APInt(const APInt &)> mapping) const {
364   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
365     return intOrFpAttr.mapValues(newElementType, mapping);
366   llvm_unreachable("unsupported ElementsAttr subtype");
367 }
368 
369 ElementsAttr
370 ElementsAttr::mapValues(Type newElementType,
371                         function_ref<APInt(const APFloat &)> mapping) const {
372   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
373     return intOrFpAttr.mapValues(newElementType, mapping);
374   llvm_unreachable("unsupported ElementsAttr subtype");
375 }
376 
377 /// Method for support type inquiry through isa, cast and dyn_cast.
378 bool ElementsAttr::classof(Attribute attr) {
379   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
380                   OpaqueElementsAttr, SparseElementsAttr>();
381 }
382 
383 /// Returns the 1 dimensional flattened row-major index from the given
384 /// multi-dimensional index.
385 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
386   assert(isValidIndex(index) && "expected valid multi-dimensional index");
387   auto type = getType();
388 
389   // Reduce the provided multidimensional index into a flattended 1D row-major
390   // index.
391   auto rank = type.getRank();
392   auto shape = type.getShape();
393   uint64_t valueIndex = 0;
394   uint64_t dimMultiplier = 1;
395   for (int i = rank - 1; i >= 0; --i) {
396     valueIndex += index[i] * dimMultiplier;
397     dimMultiplier *= shape[i];
398   }
399   return valueIndex;
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // DenseElementsAttr Utilities
404 //===----------------------------------------------------------------------===//
405 
406 /// Get the bitwidth of a dense element type within the buffer.
407 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
408 static size_t getDenseElementStorageWidth(size_t origWidth) {
409   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
410 }
411 static size_t getDenseElementStorageWidth(Type elementType) {
412   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
413 }
414 
/// Set a bit to a specific value. Bit `bitPos` lives in byte
/// `bitPos / CHAR_BIT` at position `bitPos % CHAR_BIT` (counted from the
/// least significant bit of that byte).
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const unsigned mask = 1u << (bitPos % CHAR_BIT);
  if (!value) {
    byte &= ~mask;
    return;
  }
  byte |= mask;
}
422 
/// Return the value of the specified bit, using the same layout as setBit:
/// byte `bitPos / CHAR_BIT`, bit `bitPos % CHAR_BIT` from the LSB.
static bool getBit(const char *rawData, size_t bitPos) {
  const auto byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
427 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Only valid on big-endian hosts (asserted). All words of `value` except the
/// last are copied verbatim. The last word is then re-ordered through a
/// scratch buffer (see the examples below), and only the leading
/// `numBytes - lastWordPos` bytes' worth of it is written after the full words.
///
/// NOTE(review): `numBytes` must be greater than
/// `(value.getNumWords() - 1) * APInt::APINT_WORD_SIZE`; otherwise
/// `numBytes - lastWordPos` below underflows. writeBits passes
/// divideCeil(bitWidth, CHAR_BIT) for an APInt of bitWidth bits, which
/// presumably always satisfies this. Confirm before adding new callers.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Despite its name, `numFilledWords` is a byte count: the total size of all
  // words except the last. It is also the byte offset of the last word.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  // Converts one element spanning a whole word (APINT_BITS_PER_WORD bits).
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  // Here the single converted element is only as wide as the bytes remaining
  // after the full words, so nothing past `result + numBytes` is written.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
457 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// This is the reading counterpart of copyAPIntToArrayForBEmachine. It is only
/// valid on big-endian hosts (asserted).
///
/// All words of `result` except the last are filled verbatim from `inArray`.
/// The remaining `numBytes - lastWordPos` bytes are converted through a
/// scratch word into the last word of `result` (see the examples below).
///
/// NOTE(review): as with the writer, `numBytes` must exceed the byte size of
/// all but the last word, or the size computation below underflows. readBits
/// passes divideCeil(bitWidth, CHAR_BIT) for an APInt of bitWidth bits.
///
/// NOTE(review): `result` is modified through `const_cast` on the pointer
/// returned by `getRawData()`. This relies on the APInt's word storage not
/// being a const object.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Despite its name, `numFilledWords` is a byte count (all words but the
  // last), which also serves as the byte offset of the last word.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  // Converts one element spanning a whole word (APINT_BITS_PER_WORD bits).
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
493 
494 /// Writes value to the bit position `bitPos` in array `rawData`.
495 static void writeBits(char *rawData, size_t bitPos, APInt value) {
496   size_t bitWidth = value.getBitWidth();
497 
498   // If the bitwidth is 1 we just toggle the specific bit.
499   if (bitWidth == 1)
500     return setBit(rawData, bitPos, value.isOneValue());
501 
502   // Otherwise, the bit position is guaranteed to be byte aligned.
503   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
504   if (llvm::support::endian::system_endianness() ==
505       llvm::support::endianness::big) {
506     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
507     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
508     // work correctly in BE format.
509     // ex. `value` (2 words including 10 bytes)
510     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
511     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
512                                  rawData + (bitPos / CHAR_BIT));
513   } else {
514     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
515                 llvm::divideCeil(bitWidth, CHAR_BIT),
516                 rawData + (bitPos / CHAR_BIT));
517   }
518 }
519 
520 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
521 /// `rawData`.
522 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
523   // Handle a boolean bit position.
524   if (bitWidth == 1)
525     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
526 
527   // Otherwise, the bit position must be 8-bit aligned.
528   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
529   APInt result(bitWidth, 0);
530   if (llvm::support::endian::system_endianness() ==
531       llvm::support::endianness::big) {
532     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
533     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
534     // work correctly in BE format.
535     // ex. `result` (2 words including 10 bytes)
536     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
537     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
538                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
539   } else {
540     std::copy_n(rawData + (bitPos / CHAR_BIT),
541                 llvm::divideCeil(bitWidth, CHAR_BIT),
542                 const_cast<char *>(
543                     reinterpret_cast<const char *>(result.getRawData())));
544   }
545   return result;
546 }
547 
548 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
549 /// the same element count as 'type'.
550 template <typename Values>
551 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
552   return (values.size() == 1) ||
553          (type.getNumElements() == static_cast<int64_t>(values.size()));
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // DenseElementsAttr Iterators
558 //===----------------------------------------------------------------------===//
559 
560 //===----------------------------------------------------------------------===//
561 // AttributeElementIterator
562 
563 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
564     DenseElementsAttr attr, size_t index)
565     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
566                                       Attribute, Attribute, Attribute>(
567           attr.getAsOpaquePointer(), index) {}
568 
569 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
570   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
571   Type eltTy = owner.getType().getElementType();
572   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
573     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
574   if (eltTy.isa<IndexType>())
575     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
576   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
577     IntElementIterator intIt(owner, index);
578     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
579     return FloatAttr::get(eltTy, *floatIt);
580   }
581   if (owner.isa<DenseStringElementsAttr>()) {
582     ArrayRef<StringRef> vals = owner.getRawStringData();
583     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
584   }
585   llvm_unreachable("unexpected element type");
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // BoolElementIterator
590 
591 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
592     DenseElementsAttr attr, size_t dataIndex)
593     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
594           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
595 
596 bool DenseElementsAttr::BoolElementIterator::operator*() const {
597   return getBit(getData(), getDataIndex());
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // IntElementIterator
602 
603 DenseElementsAttr::IntElementIterator::IntElementIterator(
604     DenseElementsAttr attr, size_t dataIndex)
605     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
606           attr.getRawData().data(), attr.isSplat(), dataIndex),
607       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
608 
609 APInt DenseElementsAttr::IntElementIterator::operator*() const {
610   return readBits(getData(),
611                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
612                   bitWidth);
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // ComplexIntElementIterator
617 
618 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
619     DenseElementsAttr attr, size_t dataIndex)
620     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
621                                       std::complex<APInt>, std::complex<APInt>,
622                                       std::complex<APInt>>(
623           attr.getRawData().data(), attr.isSplat(), dataIndex) {
624   auto complexType = attr.getType().getElementType().cast<ComplexType>();
625   bitWidth = getDenseElementBitWidth(complexType.getElementType());
626 }
627 
628 std::complex<APInt>
629 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
630   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
631   size_t offset = getDataIndex() * storageWidth * 2;
632   return {readBits(getData(), offset, bitWidth),
633           readBits(getData(), offset + storageWidth, bitWidth)};
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // FloatElementIterator
638 
639 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
640     const llvm::fltSemantics &smt, IntElementIterator it)
641     : llvm::mapped_iterator<IntElementIterator,
642                             std::function<APFloat(const APInt &)>>(
643           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
644 
645 //===----------------------------------------------------------------------===//
646 // ComplexFloatElementIterator
647 
648 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
649     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
650     : llvm::mapped_iterator<
651           ComplexIntElementIterator,
652           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
653           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
654             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
655           }) {}
656 
657 //===----------------------------------------------------------------------===//
658 // DenseElementsAttr
659 //===----------------------------------------------------------------------===//
660 
661 /// Method for support type inquiry through isa, cast and dyn_cast.
662 bool DenseElementsAttr::classof(Attribute attr) {
663   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
664 }
665 
666 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
667                                          ArrayRef<Attribute> values) {
668   assert(hasSameElementsOrSplat(type, values));
669 
670   // If the element type is not based on int/float/index, assume it is a string
671   // type.
672   auto eltType = type.getElementType();
673   if (!type.getElementType().isIntOrIndexOrFloat()) {
674     SmallVector<StringRef, 8> stringValues;
675     stringValues.reserve(values.size());
676     for (Attribute attr : values) {
677       assert(attr.isa<StringAttr>() &&
678              "expected string value for non integer/index/float element");
679       stringValues.push_back(attr.cast<StringAttr>().getValue());
680     }
681     return get(type, stringValues);
682   }
683 
684   // Otherwise, get the raw storage width to use for the allocation.
685   size_t bitWidth = getDenseElementBitWidth(eltType);
686   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
687 
688   // Compress the attribute values into a character buffer.
689   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
690                             values.size());
691   APInt intVal;
692   for (unsigned i = 0, e = values.size(); i < e; ++i) {
693     assert(eltType == values[i].getType() &&
694            "expected attribute value to have element type");
695     if (eltType.isa<FloatType>())
696       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
697     else if (eltType.isa<IntegerType, IndexType>())
698       intVal = values[i].cast<IntegerAttr>().getValue();
699     else
700       llvm_unreachable("unexpected element type");
701 
702     assert(intVal.getBitWidth() == bitWidth &&
703            "expected value to have same bitwidth as element type");
704     writeBits(data.data(), i * storageBitWidth, intVal);
705   }
706   return DenseIntOrFPElementsAttr::getRaw(type, data,
707                                           /*isSplat=*/(values.size() == 1));
708 }
709 
710 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
711                                          ArrayRef<bool> values) {
712   assert(hasSameElementsOrSplat(type, values));
713   assert(type.getElementType().isInteger(1));
714 
715   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
716   for (int i = 0, e = values.size(); i != e; ++i)
717     setBit(buff.data(), i, values[i]);
718   return DenseIntOrFPElementsAttr::getRaw(type, buff,
719                                           /*isSplat=*/(values.size() == 1));
720 }
721 
722 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
723                                          ArrayRef<StringRef> values) {
724   assert(!type.getElementType().isIntOrFloat());
725   return DenseStringElementsAttr::get(type, values);
726 }
727 
728 /// Constructs a dense integer elements attribute from an array of APInt
729 /// values. Each APInt value is expected to have the same bitwidth as the
730 /// element type of 'type'.
731 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
732                                          ArrayRef<APInt> values) {
733   assert(type.getElementType().isIntOrIndex());
734   assert(hasSameElementsOrSplat(type, values));
735   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
736   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
737                                           /*isSplat=*/(values.size() == 1));
738 }
739 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
740                                          ArrayRef<std::complex<APInt>> values) {
741   ComplexType complex = type.getElementType().cast<ComplexType>();
742   assert(complex.getElementType().isa<IntegerType>());
743   assert(hasSameElementsOrSplat(type, values));
744   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
745   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
746                           values.size() * 2);
747   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
748                                           /*isSplat=*/(values.size() == 1));
749 }
750 
751 // Constructs a dense float elements attribute from an array of APFloat
752 // values. Each APFloat value is expected to have the same bitwidth as the
753 // element type of 'type'.
754 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
755                                          ArrayRef<APFloat> values) {
756   assert(type.getElementType().isa<FloatType>());
757   assert(hasSameElementsOrSplat(type, values));
758   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
759   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
760                                           /*isSplat=*/(values.size() == 1));
761 }
762 DenseElementsAttr
763 DenseElementsAttr::get(ShapedType type,
764                        ArrayRef<std::complex<APFloat>> values) {
765   ComplexType complex = type.getElementType().cast<ComplexType>();
766   assert(complex.getElementType().isa<FloatType>());
767   assert(hasSameElementsOrSplat(type, values));
768   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
769                            values.size() * 2);
770   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
771   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
772                                           /*isSplat=*/(values.size() == 1));
773 }
774 
775 /// Construct a dense elements attribute from a raw buffer representing the
776 /// data for this attribute. Users should generally not use this methods as
777 /// the expected buffer format may not be a form the user expects.
778 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
779                                                       ArrayRef<char> rawBuffer,
780                                                       bool isSplatBuffer) {
781   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
782 }
783 
784 /// Returns true if the given buffer is a valid raw buffer for the given type.
785 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
786                                          ArrayRef<char> rawBuffer,
787                                          bool &detectedSplat) {
788   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
789   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
790 
791   // Storage width of 1 is special as it is packed by the bit.
792   if (storageWidth == 1) {
793     // Check for a splat, or a buffer equal to the number of elements.
794     if ((detectedSplat = rawBuffer.size() == 1))
795       return true;
796     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
797   }
798   // All other types are 8-bit aligned.
799   if ((detectedSplat = rawBufferWidth == storageWidth))
800     return true;
801   return rawBufferWidth == (storageWidth * type.getNumElements());
802 }
803 
804 /// Check the information for a C++ data type, check if this type is valid for
805 /// the current attribute. This method is used to verify specific type
806 /// invariants that the templatized 'getValues' method cannot.
807 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
808                               bool isSigned) {
809   // Make sure that the data element size is the same as the type element width.
810   if (getDenseElementBitWidth(type) !=
811       static_cast<size_t>(dataEltSize * CHAR_BIT))
812     return false;
813 
814   // Check that the element type is either float or integer or index.
815   if (!isInt)
816     return type.isa<FloatType>();
817   if (type.isIndex())
818     return true;
819 
820   auto intType = type.dyn_cast<IntegerType>();
821   if (!intType)
822     return false;
823 
824   // Make sure signedness semantics is consistent.
825   if (intType.isSignless())
826     return true;
827   return intType.isSigned() ? isSigned : !isSigned;
828 }
829 
830 /// Defaults down the subclass implementation.
831 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
832                                                    ArrayRef<char> data,
833                                                    int64_t dataEltSize,
834                                                    bool isInt, bool isSigned) {
835   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
836                                                  isSigned);
837 }
838 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
839                                                       ArrayRef<char> data,
840                                                       int64_t dataEltSize,
841                                                       bool isInt,
842                                                       bool isSigned) {
843   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
844                                                     isInt, isSigned);
845 }
846 
847 /// A method used to verify specific type invariants that the templatized 'get'
848 /// method cannot.
849 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
850                                           bool isSigned) const {
851   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
852                              isSigned);
853 }
854 
855 /// Check the information for a C++ data type, check if this type is valid for
856 /// the current attribute.
857 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
858                                        bool isSigned) const {
859   return ::isValidIntOrFloat(
860       getType().getElementType().cast<ComplexType>().getElementType(),
861       dataEltSize / 2, isInt, isSigned);
862 }
863 
864 /// Returns true if this attribute corresponds to a splat, i.e. if all element
865 /// values are the same.
866 bool DenseElementsAttr::isSplat() const {
867   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
868 }
869 
870 /// Return the held element values as a range of Attributes.
871 auto DenseElementsAttr::getAttributeValues() const
872     -> llvm::iterator_range<AttributeElementIterator> {
873   return {attr_value_begin(), attr_value_end()};
874 }
875 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
876   return AttributeElementIterator(*this, 0);
877 }
878 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
879   return AttributeElementIterator(*this, getNumElements());
880 }
881 
882 /// Return the held element values as a range of bool. The element type of
883 /// this attribute must be of integer type of bitwidth 1.
884 auto DenseElementsAttr::getBoolValues() const
885     -> llvm::iterator_range<BoolElementIterator> {
886   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
887   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
888   (void)eltType;
889   return {BoolElementIterator(*this, 0),
890           BoolElementIterator(*this, getNumElements())};
891 }
892 
893 /// Return the held element values as a range of APInts. The element type of
894 /// this attribute must be of integer type.
895 auto DenseElementsAttr::getIntValues() const
896     -> llvm::iterator_range<IntElementIterator> {
897   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
898   return {raw_int_begin(), raw_int_end()};
899 }
900 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
901   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
902   return raw_int_begin();
903 }
904 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
905   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
906   return raw_int_end();
907 }
908 auto DenseElementsAttr::getComplexIntValues() const
909     -> llvm::iterator_range<ComplexIntElementIterator> {
910   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
911   (void)eltTy;
912   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
913   return {ComplexIntElementIterator(*this, 0),
914           ComplexIntElementIterator(*this, getNumElements())};
915 }
916 
917 /// Return the held element values as a range of APFloat. The element type of
918 /// this attribute must be of float type.
919 auto DenseElementsAttr::getFloatValues() const
920     -> llvm::iterator_range<FloatElementIterator> {
921   auto elementType = getType().getElementType().cast<FloatType>();
922   const auto &elementSemantics = elementType.getFloatSemantics();
923   return {FloatElementIterator(elementSemantics, raw_int_begin()),
924           FloatElementIterator(elementSemantics, raw_int_end())};
925 }
926 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
927   return getFloatValues().begin();
928 }
929 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
930   return getFloatValues().end();
931 }
932 auto DenseElementsAttr::getComplexFloatValues() const
933     -> llvm::iterator_range<ComplexFloatElementIterator> {
934   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
935   assert(eltTy.isa<FloatType>() && "expected complex float type");
936   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
937   return {{semantics, {*this, 0}},
938           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
939 }
940 
941 /// Return the raw storage data held by this attribute.
942 ArrayRef<char> DenseElementsAttr::getRawData() const {
943   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
944 }
945 
946 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
947   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
948 }
949 
950 /// Return a new DenseElementsAttr that has the same data as the current
951 /// attribute, but has been reshaped to 'newType'. The new type must have the
952 /// same total number of elements as well as element type.
953 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
954   ShapedType curType = getType();
955   if (curType == newType)
956     return *this;
957 
958   (void)curType;
959   assert(newType.getElementType() == curType.getElementType() &&
960          "expected the same element type");
961   assert(newType.getNumElements() == curType.getNumElements() &&
962          "expected the same number of elements");
963   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
964 }
965 
966 DenseElementsAttr
967 DenseElementsAttr::mapValues(Type newElementType,
968                              function_ref<APInt(const APInt &)> mapping) const {
969   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
970 }
971 
972 DenseElementsAttr DenseElementsAttr::mapValues(
973     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
974   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // DenseIntOrFPElementsAttr
979 //===----------------------------------------------------------------------===//
980 
981 /// Utility method to write a range of APInt values to a buffer.
982 template <typename APRangeT>
983 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
984                                 APRangeT &&values) {
985   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
986   size_t offset = 0;
987   for (auto it = values.begin(), e = values.end(); it != e;
988        ++it, offset += storageWidth) {
989     assert((*it).getBitWidth() <= storageWidth);
990     writeBits(data.data(), offset, *it);
991   }
992 }
993 
994 /// Constructs a dense elements attribute from an array of raw APFloat values.
995 /// Each APFloat value is expected to have the same bitwidth as the element
996 /// type of 'type'. 'type' must be a vector or tensor with static shape.
997 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
998                                                    size_t storageWidth,
999                                                    ArrayRef<APFloat> values,
1000                                                    bool isSplat) {
1001   std::vector<char> data;
1002   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1003   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1004   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1005 }
1006 
1007 /// Constructs a dense elements attribute from an array of raw APInt values.
1008 /// Each APInt value is expected to have the same bitwidth as the element type
1009 /// of 'type'.
1010 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1011                                                    size_t storageWidth,
1012                                                    ArrayRef<APInt> values,
1013                                                    bool isSplat) {
1014   std::vector<char> data;
1015   writeAPIntsToBuffer(storageWidth, data, values);
1016   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1017 }
1018 
1019 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1020                                                    ArrayRef<char> data,
1021                                                    bool isSplat) {
1022   assert((type.isa<RankedTensorType, VectorType>()) &&
1023          "type must be ranked tensor or vector");
1024   assert(type.hasStaticShape() && "type must have static shape");
1025   return Base::get(type.getContext(), type, data, isSplat);
1026 }
1027 
1028 /// Overload of the raw 'get' method that asserts that the given type is of
1029 /// complex type. This method is used to verify type invariants that the
1030 /// templatized 'get' method cannot.
1031 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1032                                                           ArrayRef<char> data,
1033                                                           int64_t dataEltSize,
1034                                                           bool isInt,
1035                                                           bool isSigned) {
1036   assert(::isValidIntOrFloat(
1037       type.getElementType().cast<ComplexType>().getElementType(),
1038       dataEltSize / 2, isInt, isSigned));
1039 
1040   int64_t numElements = data.size() / dataEltSize;
1041   assert(numElements == 1 || numElements == type.getNumElements());
1042   return getRaw(type, data, /*isSplat=*/numElements == 1);
1043 }
1044 
1045 /// Overload of the 'getRaw' method that asserts that the given type is of
1046 /// integer type. This method is used to verify type invariants that the
1047 /// templatized 'get' method cannot.
1048 DenseElementsAttr
1049 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1050                                            int64_t dataEltSize, bool isInt,
1051                                            bool isSigned) {
1052   assert(
1053       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1054 
1055   int64_t numElements = data.size() / dataEltSize;
1056   assert(numElements == 1 || numElements == type.getNumElements());
1057   return getRaw(type, data, /*isSplat=*/numElements == 1);
1058 }
1059 
1060 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1061     const char *inRawData, char *outRawData, size_t elementBitWidth,
1062     size_t numElements) {
1063   using llvm::support::ulittle16_t;
1064   using llvm::support::ulittle32_t;
1065   using llvm::support::ulittle64_t;
1066 
1067   assert(llvm::support::endian::system_endianness() == // NOLINT
1068          llvm::support::endianness::big);              // NOLINT
1069   // NOLINT to avoid warning message about replacing by static_assert()
1070 
1071   // Following std::copy_n always converts endianness on BE machine.
1072   switch (elementBitWidth) {
1073   case 16: {
1074     const ulittle16_t *inRawDataPos =
1075         reinterpret_cast<const ulittle16_t *>(inRawData);
1076     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1077     std::copy_n(inRawDataPos, numElements, outDataPos);
1078     break;
1079   }
1080   case 32: {
1081     const ulittle32_t *inRawDataPos =
1082         reinterpret_cast<const ulittle32_t *>(inRawData);
1083     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1084     std::copy_n(inRawDataPos, numElements, outDataPos);
1085     break;
1086   }
1087   case 64: {
1088     const ulittle64_t *inRawDataPos =
1089         reinterpret_cast<const ulittle64_t *>(inRawData);
1090     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1091     std::copy_n(inRawDataPos, numElements, outDataPos);
1092     break;
1093   }
1094   default: {
1095     size_t nBytes = elementBitWidth / CHAR_BIT;
1096     for (size_t i = 0; i < nBytes; i++)
1097       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1098     break;
1099   }
1100   }
1101 }
1102 
1103 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1104     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1105     ShapedType type) {
1106   size_t numElements = type.getNumElements();
1107   Type elementType = type.getElementType();
1108   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1109     elementType = complexTy.getElementType();
1110     numElements = numElements * 2;
1111   }
1112   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1113   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1114          inRawData.size() <= outRawData.size());
1115   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1116                                   elementBitWidth, numElements);
1117 }
1118 
1119 //===----------------------------------------------------------------------===//
1120 // DenseFPElementsAttr
1121 //===----------------------------------------------------------------------===//
1122 
1123 template <typename Fn, typename Attr>
1124 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1125                                 Type newElementType,
1126                                 llvm::SmallVectorImpl<char> &data) {
1127   size_t bitWidth = getDenseElementBitWidth(newElementType);
1128   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1129 
1130   ShapedType newArrayType;
1131   if (inType.isa<RankedTensorType>())
1132     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1133   else if (inType.isa<UnrankedTensorType>())
1134     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1135   else if (inType.isa<VectorType>())
1136     newArrayType = VectorType::get(inType.getShape(), newElementType);
1137   else
1138     assert(newArrayType && "Unhandled tensor type");
1139 
1140   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1141   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1142 
1143   // Functor used to process a single element value of the attribute.
1144   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1145     auto newInt = mapping(value);
1146     assert(newInt.getBitWidth() == bitWidth);
1147     writeBits(data.data(), index * storageBitWidth, newInt);
1148   };
1149 
1150   // Check for the splat case.
1151   if (attr.isSplat()) {
1152     processElt(*attr.begin(), /*index=*/0);
1153     return newArrayType;
1154   }
1155 
1156   // Otherwise, process all of the element values.
1157   uint64_t elementIdx = 0;
1158   for (auto value : attr)
1159     processElt(value, elementIdx++);
1160   return newArrayType;
1161 }
1162 
1163 DenseElementsAttr DenseFPElementsAttr::mapValues(
1164     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1165   llvm::SmallVector<char, 8> elementData;
1166   auto newArrayType =
1167       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1168 
1169   return getRaw(newArrayType, elementData, isSplat());
1170 }
1171 
1172 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1173 bool DenseFPElementsAttr::classof(Attribute attr) {
1174   return attr.isa<DenseElementsAttr>() &&
1175          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1176 }
1177 
1178 //===----------------------------------------------------------------------===//
1179 // DenseIntElementsAttr
1180 //===----------------------------------------------------------------------===//
1181 
1182 DenseElementsAttr DenseIntElementsAttr::mapValues(
1183     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1184   llvm::SmallVector<char, 8> elementData;
1185   auto newArrayType =
1186       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1187 
1188   return getRaw(newArrayType, elementData, isSplat());
1189 }
1190 
1191 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1192 bool DenseIntElementsAttr::classof(Attribute attr) {
1193   return attr.isa<DenseElementsAttr>() &&
1194          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // OpaqueElementsAttr
1199 //===----------------------------------------------------------------------===//
1200 
1201 /// Return the value at the given index. If index does not refer to a valid
1202 /// element, then a null attribute is returned.
1203 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1204   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1205   return Attribute();
1206 }
1207 
1208 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1209   Dialect *dialect = getDialect().getDialect();
1210   if (!dialect)
1211     return true;
1212   auto *interface =
1213       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1214   if (!interface)
1215     return true;
1216   return failed(interface->decode(*this, result));
1217 }
1218 
1219 LogicalResult
1220 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1221                            Identifier dialect, StringRef value,
1222                            ShapedType type) {
1223   if (!Dialect::isValidNamespace(dialect.strref()))
1224     return emitError() << "invalid dialect namespace '" << dialect << "'";
1225   return success();
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // SparseElementsAttr
1230 //===----------------------------------------------------------------------===//
1231 
/// Return the value of the element at the given index.
///
/// Positions that do not appear in the sparse index set yield the zero value
/// produced by getZeroAttr(). In debug builds `index` must be a valid
/// multi-dimensional index for this attribute's type.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // A splat index set stores one coordinate value used for every dimension.
    // If any component of the queried index differs from it, the element is
    // not stored and its value is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // Row `i` of the indices is the `rank`-long slice starting at flat position
  // `i * rank`, viewed in place as an ArrayRef key. try_emplace keeps the first
  // offset if the same index row appears more than once.
  // NOTE(review): taking `&*` of the element iterator assumes the index values
  // are stored contiguously as uint64_t in the raw buffer; confirm. The map is
  // rebuilt on every call, so each lookup costs O(number of sparse indices).
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
1272 
1273 /// Get a zero APFloat for the given sparse attribute.
1274 APFloat SparseElementsAttr::getZeroAPFloat() const {
1275   auto eltType = getType().getElementType().cast<FloatType>();
1276   return APFloat(eltType.getFloatSemantics());
1277 }
1278 
1279 /// Get a zero APInt for the given sparse attribute.
1280 APInt SparseElementsAttr::getZeroAPInt() const {
1281   auto eltType = getType().getElementType().cast<IntegerType>();
1282   return APInt::getNullValue(eltType.getWidth());
1283 }
1284 
1285 /// Get a zero attribute for the given attribute type.
1286 Attribute SparseElementsAttr::getZeroAttr() const {
1287   auto eltType = getType().getElementType();
1288 
1289   // Handle floating point elements.
1290   if (eltType.isa<FloatType>())
1291     return FloatAttr::get(eltType, 0);
1292 
1293   // Otherwise, this is an integer.
1294   // TODO: Handle StringAttr here.
1295   return IntegerAttr::get(eltType, 0);
1296 }
1297 
1298 /// Flatten, and return, all of the sparse indices in this attribute in
1299 /// row-major order.
1300 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1301   std::vector<ptrdiff_t> flatSparseIndices;
1302 
1303   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1304   // as a 1-D index array.
1305   auto sparseIndices = getIndices();
1306   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1307   if (sparseIndices.isSplat()) {
1308     SmallVector<uint64_t, 8> indices(getType().getRank(),
1309                                      *sparseIndexValues.begin());
1310     flatSparseIndices.push_back(getFlattenedIndex(indices));
1311     return flatSparseIndices;
1312   }
1313 
1314   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1315   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1316   size_t rank = getType().getRank();
1317   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1318     flatSparseIndices.push_back(getFlattenedIndex(
1319         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1320   return flatSparseIndices;
1321 }
1322