1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 //===----------------------------------------------------------------------===//
26 /// Tablegen Attribute Definitions
27 //===----------------------------------------------------------------------===//
28 
29 #define GET_ATTRDEF_CLASSES
30 #include "mlir/IR/BuiltinAttributes.cpp.inc"
31 
32 //===----------------------------------------------------------------------===//
33 // BuiltinDialect
34 //===----------------------------------------------------------------------===//
35 
/// Registers the builtin attribute classes listed below with the builtin
/// dialect by forwarding them to `addAttributes`.
void BuiltinDialect::registerAttributes() {
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
}
43 
44 //===----------------------------------------------------------------------===//
45 // DictionaryAttr
46 //===----------------------------------------------------------------------===//
47 
48 /// Helper function that does either an in place sort or sorts from source array
49 /// into destination. If inPlace then storage is both the source and the
50 /// destination, else value is the source and storage destination. Returns
51 /// whether source was sorted.
52 template <bool inPlace>
53 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
54                                SmallVectorImpl<NamedAttribute> &storage) {
55   // Specialize for the common case.
56   switch (value.size()) {
57   case 0:
58     // Zero already sorted.
59     break;
60   case 1:
61     // One already sorted but may need to be copied.
62     if (!inPlace)
63       storage.assign({value[0]});
64     break;
65   case 2: {
66     bool isSorted = value[0] < value[1];
67     if (inPlace) {
68       if (!isSorted)
69         std::swap(storage[0], storage[1]);
70     } else if (isSorted) {
71       storage.assign({value[0], value[1]});
72     } else {
73       storage.assign({value[1], value[0]});
74     }
75     return !isSorted;
76   }
77   default:
78     if (!inPlace)
79       storage.assign(value.begin(), value.end());
80     // Check to see they are sorted already.
81     bool isSorted = llvm::is_sorted(value);
82     if (!isSorted) {
83       // If not, do a general sort.
84       llvm::array_pod_sort(storage.begin(), storage.end());
85       value = storage;
86     }
87     return !isSorted;
88   }
89   return false;
90 }
91 
92 /// Returns an entry with a duplicate name from the given sorted array of named
93 /// attributes. Returns llvm::None if all elements have unique names.
94 static Optional<NamedAttribute>
95 findDuplicateElement(ArrayRef<NamedAttribute> value) {
96   const Optional<NamedAttribute> none{llvm::None};
97   if (value.size() < 2)
98     return none;
99 
100   if (value.size() == 2)
101     return value[0].first == value[1].first ? value[0] : none;
102 
103   auto it = std::adjacent_find(
104       value.begin(), value.end(),
105       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
106   return it != value.end() ? *it : none;
107 }
108 
109 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
110                           SmallVectorImpl<NamedAttribute> &storage) {
111   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
112   assert(!findDuplicateElement(storage) &&
113          "DictionaryAttr element names must be unique");
114   return isSorted;
115 }
116 
117 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
118   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
119   assert(!findDuplicateElement(array) &&
120          "DictionaryAttr element names must be unique");
121   return isSorted;
122 }
123 
124 Optional<NamedAttribute>
125 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
126                               bool isSorted) {
127   if (!isSorted)
128     dictionaryAttrSort</*inPlace=*/true>(array, array);
129   return findDuplicateElement(array);
130 }
131 
132 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
133                                    ArrayRef<NamedAttribute> value) {
134   if (value.empty())
135     return DictionaryAttr::getEmpty(context);
136   assert(llvm::all_of(value,
137                       [](const NamedAttribute &attr) { return attr.second; }) &&
138          "value cannot have null entries");
139 
140   // We need to sort the element list to canonicalize it.
141   SmallVector<NamedAttribute, 8> storage;
142   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
143     value = storage;
144   assert(!findDuplicateElement(value) &&
145          "DictionaryAttr element names must be unique");
146   return Base::get(context, value);
147 }
148 /// Construct a dictionary with an array of values that is known to already be
149 /// sorted by name and uniqued.
150 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
151                                              ArrayRef<NamedAttribute> value) {
152   if (value.empty())
153     return DictionaryAttr::getEmpty(context);
154   // Ensure that the attribute elements are unique and sorted.
155   assert(llvm::is_sorted(value,
156                          [](NamedAttribute l, NamedAttribute r) {
157                            return l.first.strref() < r.first.strref();
158                          }) &&
159          "expected attribute values to be sorted");
160   assert(!findDuplicateElement(value) &&
161          "DictionaryAttr element names must be unique");
162   return Base::get(context, value);
163 }
164 
165 /// Return the specified attribute if present, null otherwise.
166 Attribute DictionaryAttr::get(StringRef name) const {
167   Optional<NamedAttribute> attr = getNamed(name);
168   return attr ? attr->second : nullptr;
169 }
170 Attribute DictionaryAttr::get(Identifier name) const {
171   Optional<NamedAttribute> attr = getNamed(name);
172   return attr ? attr->second : nullptr;
173 }
174 
175 /// Return the specified named attribute if present, None otherwise.
176 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
177   ArrayRef<NamedAttribute> values = getValue();
178   const auto *it = llvm::lower_bound(values, name);
179   return it != values.end() && it->first == name ? *it
180                                                  : Optional<NamedAttribute>();
181 }
182 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
183   for (auto elt : getValue())
184     if (elt.first == name)
185       return elt;
186   return llvm::None;
187 }
188 
189 DictionaryAttr::iterator DictionaryAttr::begin() const {
190   return getValue().begin();
191 }
192 DictionaryAttr::iterator DictionaryAttr::end() const {
193   return getValue().end();
194 }
195 size_t DictionaryAttr::size() const { return getValue().size(); }
196 
197 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
198   return Base::get(context, ArrayRef<NamedAttribute>());
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // FloatAttr
203 //===----------------------------------------------------------------------===//
204 
205 double FloatAttr::getValueAsDouble() const {
206   return getValueAsDouble(getValue());
207 }
208 double FloatAttr::getValueAsDouble(APFloat value) {
209   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
210     bool losesInfo = false;
211     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
212                   &losesInfo);
213   }
214   return value.convertToDouble();
215 }
216 
217 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
218                                 Type type, APFloat value) {
219   // Verify that the type is correct.
220   if (!type.isa<FloatType>())
221     return emitError() << "expected floating point type";
222 
223   // Verify that the type semantics match that of the value.
224   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
225     return emitError()
226            << "FloatAttr type doesn't match the type implied by its value";
227   }
228   return success();
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // SymbolRefAttr
233 //===----------------------------------------------------------------------===//
234 
235 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
236   return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
237 }
238 
239 StringRef SymbolRefAttr::getLeafReference() const {
240   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
241   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // IntegerAttr
246 //===----------------------------------------------------------------------===//
247 
248 int64_t IntegerAttr::getInt() const {
249   assert((getType().isIndex() || getType().isSignlessInteger()) &&
250          "must be signless integer");
251   return getValue().getSExtValue();
252 }
253 
254 int64_t IntegerAttr::getSInt() const {
255   assert(getType().isSignedInteger() && "must be signed integer");
256   return getValue().getSExtValue();
257 }
258 
259 uint64_t IntegerAttr::getUInt() const {
260   assert(getType().isUnsignedInteger() && "must be unsigned integer");
261   return getValue().getZExtValue();
262 }
263 
264 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
265                                   Type type, APInt value) {
266   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
267     if (integerType.getWidth() != value.getBitWidth())
268       return emitError() << "integer type bit width (" << integerType.getWidth()
269                          << ") doesn't match value bit width ("
270                          << value.getBitWidth() << ")";
271     return success();
272   }
273   if (type.isa<IndexType>())
274     return success();
275   return emitError() << "expected integer or index type";
276 }
277 
278 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
279   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
280   return attr.cast<BoolAttr>();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // BoolAttr
285 
286 bool BoolAttr::getValue() const {
287   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
288   return storage->value.getBoolValue();
289 }
290 
291 bool BoolAttr::classof(Attribute attr) {
292   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
293   return intAttr && intAttr.getType().isSignlessInteger(1);
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // OpaqueAttr
298 //===----------------------------------------------------------------------===//
299 
300 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
301                                  Identifier dialect, StringRef attrData,
302                                  Type type) {
303   if (!Dialect::isValidNamespace(dialect.strref()))
304     return emitError() << "invalid dialect namespace '" << dialect << "'";
305   return success();
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // ElementsAttr
310 //===----------------------------------------------------------------------===//
311 
312 ShapedType ElementsAttr::getType() const {
313   return Attribute::getType().cast<ShapedType>();
314 }
315 
316 /// Returns the number of elements held by this attribute.
317 int64_t ElementsAttr::getNumElements() const {
318   return getType().getNumElements();
319 }
320 
321 /// Return the value at the given index. If index does not refer to a valid
322 /// element, then a null attribute is returned.
323 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
324   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
325     return denseAttr.getValue(index);
326   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
327     return opaqueAttr.getValue(index);
328   return cast<SparseElementsAttr>().getValue(index);
329 }
330 
331 /// Return if the given 'index' refers to a valid element in this attribute.
332 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
333   auto type = getType();
334 
335   // Verify that the rank of the indices matches the held type.
336   auto rank = type.getRank();
337   if (rank == 0 && index.size() == 1 && index[0] == 0)
338     return true;
339   if (rank != static_cast<int64_t>(index.size()))
340     return false;
341 
342   // Verify that all of the indices are within the shape dimensions.
343   auto shape = type.getShape();
344   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
345     return static_cast<int64_t>(index[i]) < shape[i];
346   });
347 }
348 
349 ElementsAttr
350 ElementsAttr::mapValues(Type newElementType,
351                         function_ref<APInt(const APInt &)> mapping) const {
352   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
353     return intOrFpAttr.mapValues(newElementType, mapping);
354   llvm_unreachable("unsupported ElementsAttr subtype");
355 }
356 
357 ElementsAttr
358 ElementsAttr::mapValues(Type newElementType,
359                         function_ref<APInt(const APFloat &)> mapping) const {
360   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
361     return intOrFpAttr.mapValues(newElementType, mapping);
362   llvm_unreachable("unsupported ElementsAttr subtype");
363 }
364 
365 /// Method for support type inquiry through isa, cast and dyn_cast.
366 bool ElementsAttr::classof(Attribute attr) {
367   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
368                   OpaqueElementsAttr, SparseElementsAttr>();
369 }
370 
371 /// Returns the 1 dimensional flattened row-major index from the given
372 /// multi-dimensional index.
373 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
374   assert(isValidIndex(index) && "expected valid multi-dimensional index");
375   auto type = getType();
376 
377   // Reduce the provided multidimensional index into a flattended 1D row-major
378   // index.
379   auto rank = type.getRank();
380   auto shape = type.getShape();
381   uint64_t valueIndex = 0;
382   uint64_t dimMultiplier = 1;
383   for (int i = rank - 1; i >= 0; --i) {
384     valueIndex += index[i] * dimMultiplier;
385     dimMultiplier *= shape[i];
386   }
387   return valueIndex;
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // DenseElementsAttr Utilities
392 //===----------------------------------------------------------------------===//
393 
394 /// Get the bitwidth of a dense element type within the buffer.
395 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
396 static size_t getDenseElementStorageWidth(size_t origWidth) {
397   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
398 }
399 static size_t getDenseElementStorageWidth(Type elementType) {
400   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
401 }
402 
/// Sets or clears the bit at absolute position `bitPos` in `rawData`. Bit 0
/// is the least significant bit of the first byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}
410 
/// Reads the bit at absolute position `bitPos` in `rawData`. Bit 0 is the
/// least significant bit of the first byte.
static bool getBit(const char *rawData, size_t bitPos) {
  const size_t byteIndex = bitPos / CHAR_BIT;
  const size_t bitInByte = bitPos % CHAR_BIT;
  const unsigned char byte = static_cast<unsigned char>(rawData[byteIndex]);
  return ((byte >> bitInByte) & 1u) != 0;
}
415 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Only valid on a big-endian host (asserted). `value` must provide at least
/// `numBytes` bytes of word storage (asserted). All words but the last are
/// copied verbatim; the last word goes through two calls to
/// `convertEndianOfCharForBEmachine` so that only its meaningful bytes land
/// in `result`.
/// NOTE(review): the exact byte transformation is performed by
/// `convertEndianOfCharForBEmachine`, which is defined elsewhere; the layouts
/// sketched below are those given by the original comments.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note that despite its name, `numFilledWords` is a byte count: the number
  // of bytes covered by every word except the last one.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset at which the last word starts.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
445 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// Only valid on a big-endian host (asserted). `result` must provide at least
/// `numBytes` bytes of word storage (asserted). All words but the last are
/// copied verbatim; the remaining bytes go through two calls to
/// `convertEndianOfCharForBEmachine` to fill the last word of `result`.
/// The word storage of `result` is written directly through a `const_cast`
/// of `getRawData()`.
/// NOTE(review): the exact byte transformation is performed by
/// `convertEndianOfCharForBEmachine`, which is defined elsewhere; the layouts
/// sketched below are those given by the original comments.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note that despite its name, `numFilledWords` is a byte count: the number
  // of bytes covered by every word except the last one.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  // `lastWordPos` is the byte offset at which the last word starts.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
481 
482 /// Writes value to the bit position `bitPos` in array `rawData`.
483 static void writeBits(char *rawData, size_t bitPos, APInt value) {
484   size_t bitWidth = value.getBitWidth();
485 
486   // If the bitwidth is 1 we just toggle the specific bit.
487   if (bitWidth == 1)
488     return setBit(rawData, bitPos, value.isOneValue());
489 
490   // Otherwise, the bit position is guaranteed to be byte aligned.
491   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
492   if (llvm::support::endian::system_endianness() ==
493       llvm::support::endianness::big) {
494     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
495     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
496     // work correctly in BE format.
497     // ex. `value` (2 words including 10 bytes)
498     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
499     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
500                                  rawData + (bitPos / CHAR_BIT));
501   } else {
502     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
503                 llvm::divideCeil(bitWidth, CHAR_BIT),
504                 rawData + (bitPos / CHAR_BIT));
505   }
506 }
507 
508 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
509 /// `rawData`.
510 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
511   // Handle a boolean bit position.
512   if (bitWidth == 1)
513     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
514 
515   // Otherwise, the bit position must be 8-bit aligned.
516   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
517   APInt result(bitWidth, 0);
518   if (llvm::support::endian::system_endianness() ==
519       llvm::support::endianness::big) {
520     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
521     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
522     // work correctly in BE format.
523     // ex. `result` (2 words including 10 bytes)
524     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
525     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
526                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
527   } else {
528     std::copy_n(rawData + (bitPos / CHAR_BIT),
529                 llvm::divideCeil(bitWidth, CHAR_BIT),
530                 const_cast<char *>(
531                     reinterpret_cast<const char *>(result.getRawData())));
532   }
533   return result;
534 }
535 
536 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
537 /// the same element count as 'type'.
538 template <typename Values>
539 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
540   return (values.size() == 1) ||
541          (type.getNumElements() == static_cast<int64_t>(values.size()));
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // DenseElementsAttr Iterators
546 //===----------------------------------------------------------------------===//
547 
548 //===----------------------------------------------------------------------===//
549 // AttributeElementIterator
550 
/// Constructs an iterator over `attr` positioned at element `index`. The
/// attribute is kept as an opaque pointer in the iterator base.
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
    DenseElementsAttr attr, size_t index)
    : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                      Attribute, Attribute, Attribute>(
          attr.getAsOpaquePointer(), index) {}
556 
557 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
558   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
559   Type eltTy = owner.getType().getElementType();
560   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
561     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
562   if (eltTy.isa<IndexType>())
563     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
564   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
565     IntElementIterator intIt(owner, index);
566     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
567     return FloatAttr::get(eltTy, *floatIt);
568   }
569   if (owner.isa<DenseStringElementsAttr>()) {
570     ArrayRef<StringRef> vals = owner.getRawStringData();
571     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
572   }
573   llvm_unreachable("unexpected element type");
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // BoolElementIterator
578 
/// Constructs a boolean element iterator over the raw data of `attr`,
/// positioned at `dataIndex`. Whether `attr` is a splat is forwarded to the
/// iterator base.
DenseElementsAttr::BoolElementIterator::BoolElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
          attr.getRawData().data(), attr.isSplat(), dataIndex) {}
583 
584 bool DenseElementsAttr::BoolElementIterator::operator*() const {
585   return getBit(getData(), getDataIndex());
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // IntElementIterator
590 
/// Constructs an integer element iterator over the raw data of `attr`,
/// positioned at `dataIndex`. The dense bit width of the element type is
/// cached in `bitWidth` for use when dereferencing.
DenseElementsAttr::IntElementIterator::IntElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
          attr.getRawData().data(), attr.isSplat(), dataIndex),
      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
596 
597 APInt DenseElementsAttr::IntElementIterator::operator*() const {
598   return readBits(getData(),
599                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
600                   bitWidth);
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // ComplexIntElementIterator
605 
606 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
607     DenseElementsAttr attr, size_t dataIndex)
608     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
609                                       std::complex<APInt>, std::complex<APInt>,
610                                       std::complex<APInt>>(
611           attr.getRawData().data(), attr.isSplat(), dataIndex) {
612   auto complexType = attr.getType().getElementType().cast<ComplexType>();
613   bitWidth = getDenseElementBitWidth(complexType.getElementType());
614 }
615 
616 std::complex<APInt>
617 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
618   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
619   size_t offset = getDataIndex() * storageWidth * 2;
620   return {readBits(getData(), offset, bitWidth),
621           readBits(getData(), offset + storageWidth, bitWidth)};
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // FloatElementIterator
626 
627 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
628     const llvm::fltSemantics &smt, IntElementIterator it)
629     : llvm::mapped_iterator<IntElementIterator,
630                             std::function<APFloat(const APInt &)>>(
631           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
632 
633 //===----------------------------------------------------------------------===//
634 // ComplexFloatElementIterator
635 
636 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
637     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
638     : llvm::mapped_iterator<
639           ComplexIntElementIterator,
640           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
641           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
642             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
643           }) {}
644 
645 //===----------------------------------------------------------------------===//
646 // DenseElementsAttr
647 //===----------------------------------------------------------------------===//
648 
649 /// Method for support type inquiry through isa, cast and dyn_cast.
650 bool DenseElementsAttr::classof(Attribute attr) {
651   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
652 }
653 
654 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
655                                          ArrayRef<Attribute> values) {
656   assert(hasSameElementsOrSplat(type, values));
657 
658   // If the element type is not based on int/float/index, assume it is a string
659   // type.
660   auto eltType = type.getElementType();
661   if (!type.getElementType().isIntOrIndexOrFloat()) {
662     SmallVector<StringRef, 8> stringValues;
663     stringValues.reserve(values.size());
664     for (Attribute attr : values) {
665       assert(attr.isa<StringAttr>() &&
666              "expected string value for non integer/index/float element");
667       stringValues.push_back(attr.cast<StringAttr>().getValue());
668     }
669     return get(type, stringValues);
670   }
671 
672   // Otherwise, get the raw storage width to use for the allocation.
673   size_t bitWidth = getDenseElementBitWidth(eltType);
674   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
675 
676   // Compress the attribute values into a character buffer.
677   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
678                             values.size());
679   APInt intVal;
680   for (unsigned i = 0, e = values.size(); i < e; ++i) {
681     assert(eltType == values[i].getType() &&
682            "expected attribute value to have element type");
683     if (eltType.isa<FloatType>())
684       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
685     else if (eltType.isa<IntegerType>())
686       intVal = values[i].cast<IntegerAttr>().getValue();
687     else
688       llvm_unreachable("unexpected element type");
689 
690     assert(intVal.getBitWidth() == bitWidth &&
691            "expected value to have same bitwidth as element type");
692     writeBits(data.data(), i * storageBitWidth, intVal);
693   }
694   return DenseIntOrFPElementsAttr::getRaw(type, data,
695                                           /*isSplat=*/(values.size() == 1));
696 }
697 
698 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
699                                          ArrayRef<bool> values) {
700   assert(hasSameElementsOrSplat(type, values));
701   assert(type.getElementType().isInteger(1));
702 
703   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
704   for (int i = 0, e = values.size(); i != e; ++i)
705     setBit(buff.data(), i, values[i]);
706   return DenseIntOrFPElementsAttr::getRaw(type, buff,
707                                           /*isSplat=*/(values.size() == 1));
708 }
709 
710 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
711                                          ArrayRef<StringRef> values) {
712   assert(!type.getElementType().isIntOrFloat());
713   return DenseStringElementsAttr::get(type, values);
714 }
715 
716 /// Constructs a dense integer elements attribute from an array of APInt
717 /// values. Each APInt value is expected to have the same bitwidth as the
718 /// element type of 'type'.
719 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
720                                          ArrayRef<APInt> values) {
721   assert(type.getElementType().isIntOrIndex());
722   assert(hasSameElementsOrSplat(type, values));
723   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
724   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
725                                           /*isSplat=*/(values.size() == 1));
726 }
727 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
728                                          ArrayRef<std::complex<APInt>> values) {
729   ComplexType complex = type.getElementType().cast<ComplexType>();
730   assert(complex.getElementType().isa<IntegerType>());
731   assert(hasSameElementsOrSplat(type, values));
732   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
733   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
734                           values.size() * 2);
735   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
736                                           /*isSplat=*/(values.size() == 1));
737 }
738 
739 // Constructs a dense float elements attribute from an array of APFloat
740 // values. Each APFloat value is expected to have the same bitwidth as the
741 // element type of 'type'.
742 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
743                                          ArrayRef<APFloat> values) {
744   assert(type.getElementType().isa<FloatType>());
745   assert(hasSameElementsOrSplat(type, values));
746   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
747   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
748                                           /*isSplat=*/(values.size() == 1));
749 }
750 DenseElementsAttr
751 DenseElementsAttr::get(ShapedType type,
752                        ArrayRef<std::complex<APFloat>> values) {
753   ComplexType complex = type.getElementType().cast<ComplexType>();
754   assert(complex.getElementType().isa<FloatType>());
755   assert(hasSameElementsOrSplat(type, values));
756   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
757                            values.size() * 2);
758   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
759   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
760                                           /*isSplat=*/(values.size() == 1));
761 }
762 
763 /// Construct a dense elements attribute from a raw buffer representing the
764 /// data for this attribute. Users should generally not use this methods as
765 /// the expected buffer format may not be a form the user expects.
766 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
767                                                       ArrayRef<char> rawBuffer,
768                                                       bool isSplatBuffer) {
769   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
770 }
771 
772 /// Returns true if the given buffer is a valid raw buffer for the given type.
773 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
774                                          ArrayRef<char> rawBuffer,
775                                          bool &detectedSplat) {
776   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
777   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
778 
779   // Storage width of 1 is special as it is packed by the bit.
780   if (storageWidth == 1) {
781     // Check for a splat, or a buffer equal to the number of elements.
782     if ((detectedSplat = rawBuffer.size() == 1))
783       return true;
784     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
785   }
786   // All other types are 8-bit aligned.
787   if ((detectedSplat = rawBufferWidth == storageWidth))
788     return true;
789   return rawBufferWidth == (storageWidth * type.getNumElements());
790 }
791 
792 /// Check the information for a C++ data type, check if this type is valid for
793 /// the current attribute. This method is used to verify specific type
794 /// invariants that the templatized 'getValues' method cannot.
795 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
796                               bool isSigned) {
797   // Make sure that the data element size is the same as the type element width.
798   if (getDenseElementBitWidth(type) !=
799       static_cast<size_t>(dataEltSize * CHAR_BIT))
800     return false;
801 
802   // Check that the element type is either float or integer or index.
803   if (!isInt)
804     return type.isa<FloatType>();
805   if (type.isIndex())
806     return true;
807 
808   auto intType = type.dyn_cast<IntegerType>();
809   if (!intType)
810     return false;
811 
812   // Make sure signedness semantics is consistent.
813   if (intType.isSignless())
814     return true;
815   return intType.isSigned() ? isSigned : !isSigned;
816 }
817 
818 /// Defaults down the subclass implementation.
819 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
820                                                    ArrayRef<char> data,
821                                                    int64_t dataEltSize,
822                                                    bool isInt, bool isSigned) {
823   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
824                                                  isSigned);
825 }
826 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
827                                                       ArrayRef<char> data,
828                                                       int64_t dataEltSize,
829                                                       bool isInt,
830                                                       bool isSigned) {
831   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
832                                                     isInt, isSigned);
833 }
834 
835 /// A method used to verify specific type invariants that the templatized 'get'
836 /// method cannot.
837 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
838                                           bool isSigned) const {
839   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
840                              isSigned);
841 }
842 
843 /// Check the information for a C++ data type, check if this type is valid for
844 /// the current attribute.
845 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
846                                        bool isSigned) const {
847   return ::isValidIntOrFloat(
848       getType().getElementType().cast<ComplexType>().getElementType(),
849       dataEltSize / 2, isInt, isSigned);
850 }
851 
852 /// Returns true if this attribute corresponds to a splat, i.e. if all element
853 /// values are the same.
854 bool DenseElementsAttr::isSplat() const {
855   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
856 }
857 
858 /// Return the held element values as a range of Attributes.
859 auto DenseElementsAttr::getAttributeValues() const
860     -> llvm::iterator_range<AttributeElementIterator> {
861   return {attr_value_begin(), attr_value_end()};
862 }
863 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
864   return AttributeElementIterator(*this, 0);
865 }
866 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
867   return AttributeElementIterator(*this, getNumElements());
868 }
869 
870 /// Return the held element values as a range of bool. The element type of
871 /// this attribute must be of integer type of bitwidth 1.
872 auto DenseElementsAttr::getBoolValues() const
873     -> llvm::iterator_range<BoolElementIterator> {
874   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
875   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
876   (void)eltType;
877   return {BoolElementIterator(*this, 0),
878           BoolElementIterator(*this, getNumElements())};
879 }
880 
881 /// Return the held element values as a range of APInts. The element type of
882 /// this attribute must be of integer type.
883 auto DenseElementsAttr::getIntValues() const
884     -> llvm::iterator_range<IntElementIterator> {
885   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
886   return {raw_int_begin(), raw_int_end()};
887 }
888 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
889   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
890   return raw_int_begin();
891 }
892 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
893   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
894   return raw_int_end();
895 }
896 auto DenseElementsAttr::getComplexIntValues() const
897     -> llvm::iterator_range<ComplexIntElementIterator> {
898   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
899   (void)eltTy;
900   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
901   return {ComplexIntElementIterator(*this, 0),
902           ComplexIntElementIterator(*this, getNumElements())};
903 }
904 
905 /// Return the held element values as a range of APFloat. The element type of
906 /// this attribute must be of float type.
907 auto DenseElementsAttr::getFloatValues() const
908     -> llvm::iterator_range<FloatElementIterator> {
909   auto elementType = getType().getElementType().cast<FloatType>();
910   const auto &elementSemantics = elementType.getFloatSemantics();
911   return {FloatElementIterator(elementSemantics, raw_int_begin()),
912           FloatElementIterator(elementSemantics, raw_int_end())};
913 }
914 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
915   return getFloatValues().begin();
916 }
917 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
918   return getFloatValues().end();
919 }
920 auto DenseElementsAttr::getComplexFloatValues() const
921     -> llvm::iterator_range<ComplexFloatElementIterator> {
922   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
923   assert(eltTy.isa<FloatType>() && "expected complex float type");
924   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
925   return {{semantics, {*this, 0}},
926           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
927 }
928 
929 /// Return the raw storage data held by this attribute.
930 ArrayRef<char> DenseElementsAttr::getRawData() const {
931   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
932 }
933 
934 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
935   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
936 }
937 
938 /// Return a new DenseElementsAttr that has the same data as the current
939 /// attribute, but has been reshaped to 'newType'. The new type must have the
940 /// same total number of elements as well as element type.
941 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
942   ShapedType curType = getType();
943   if (curType == newType)
944     return *this;
945 
946   (void)curType;
947   assert(newType.getElementType() == curType.getElementType() &&
948          "expected the same element type");
949   assert(newType.getNumElements() == curType.getNumElements() &&
950          "expected the same number of elements");
951   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
952 }
953 
954 DenseElementsAttr
955 DenseElementsAttr::mapValues(Type newElementType,
956                              function_ref<APInt(const APInt &)> mapping) const {
957   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
958 }
959 
960 DenseElementsAttr DenseElementsAttr::mapValues(
961     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
962   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
963 }
964 
965 //===----------------------------------------------------------------------===//
966 // DenseIntOrFPElementsAttr
967 //===----------------------------------------------------------------------===//
968 
969 /// Utility method to write a range of APInt values to a buffer.
970 template <typename APRangeT>
971 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
972                                 APRangeT &&values) {
973   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
974   size_t offset = 0;
975   for (auto it = values.begin(), e = values.end(); it != e;
976        ++it, offset += storageWidth) {
977     assert((*it).getBitWidth() <= storageWidth);
978     writeBits(data.data(), offset, *it);
979   }
980 }
981 
982 /// Constructs a dense elements attribute from an array of raw APFloat values.
983 /// Each APFloat value is expected to have the same bitwidth as the element
984 /// type of 'type'. 'type' must be a vector or tensor with static shape.
985 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
986                                                    size_t storageWidth,
987                                                    ArrayRef<APFloat> values,
988                                                    bool isSplat) {
989   std::vector<char> data;
990   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
991   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
992   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
993 }
994 
995 /// Constructs a dense elements attribute from an array of raw APInt values.
996 /// Each APInt value is expected to have the same bitwidth as the element type
997 /// of 'type'.
998 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
999                                                    size_t storageWidth,
1000                                                    ArrayRef<APInt> values,
1001                                                    bool isSplat) {
1002   std::vector<char> data;
1003   writeAPIntsToBuffer(storageWidth, data, values);
1004   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1005 }
1006 
1007 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1008                                                    ArrayRef<char> data,
1009                                                    bool isSplat) {
1010   assert((type.isa<RankedTensorType, VectorType>()) &&
1011          "type must be ranked tensor or vector");
1012   assert(type.hasStaticShape() && "type must have static shape");
1013   return Base::get(type.getContext(), type, data, isSplat);
1014 }
1015 
1016 /// Overload of the raw 'get' method that asserts that the given type is of
1017 /// complex type. This method is used to verify type invariants that the
1018 /// templatized 'get' method cannot.
1019 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1020                                                           ArrayRef<char> data,
1021                                                           int64_t dataEltSize,
1022                                                           bool isInt,
1023                                                           bool isSigned) {
1024   assert(::isValidIntOrFloat(
1025       type.getElementType().cast<ComplexType>().getElementType(),
1026       dataEltSize / 2, isInt, isSigned));
1027 
1028   int64_t numElements = data.size() / dataEltSize;
1029   assert(numElements == 1 || numElements == type.getNumElements());
1030   return getRaw(type, data, /*isSplat=*/numElements == 1);
1031 }
1032 
1033 /// Overload of the 'getRaw' method that asserts that the given type is of
1034 /// integer type. This method is used to verify type invariants that the
1035 /// templatized 'get' method cannot.
1036 DenseElementsAttr
1037 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1038                                            int64_t dataEltSize, bool isInt,
1039                                            bool isSigned) {
1040   assert(
1041       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1042 
1043   int64_t numElements = data.size() / dataEltSize;
1044   assert(numElements == 1 || numElements == type.getNumElements());
1045   return getRaw(type, data, /*isSplat=*/numElements == 1);
1046 }
1047 
1048 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1049     const char *inRawData, char *outRawData, size_t elementBitWidth,
1050     size_t numElements) {
1051   using llvm::support::ulittle16_t;
1052   using llvm::support::ulittle32_t;
1053   using llvm::support::ulittle64_t;
1054 
1055   assert(llvm::support::endian::system_endianness() == // NOLINT
1056          llvm::support::endianness::big);              // NOLINT
1057   // NOLINT to avoid warning message about replacing by static_assert()
1058 
1059   // Following std::copy_n always converts endianness on BE machine.
1060   switch (elementBitWidth) {
1061   case 16: {
1062     const ulittle16_t *inRawDataPos =
1063         reinterpret_cast<const ulittle16_t *>(inRawData);
1064     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1065     std::copy_n(inRawDataPos, numElements, outDataPos);
1066     break;
1067   }
1068   case 32: {
1069     const ulittle32_t *inRawDataPos =
1070         reinterpret_cast<const ulittle32_t *>(inRawData);
1071     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1072     std::copy_n(inRawDataPos, numElements, outDataPos);
1073     break;
1074   }
1075   case 64: {
1076     const ulittle64_t *inRawDataPos =
1077         reinterpret_cast<const ulittle64_t *>(inRawData);
1078     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1079     std::copy_n(inRawDataPos, numElements, outDataPos);
1080     break;
1081   }
1082   default: {
1083     size_t nBytes = elementBitWidth / CHAR_BIT;
1084     for (size_t i = 0; i < nBytes; i++)
1085       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1086     break;
1087   }
1088   }
1089 }
1090 
1091 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1092     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1093     ShapedType type) {
1094   size_t numElements = type.getNumElements();
1095   Type elementType = type.getElementType();
1096   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1097     elementType = complexTy.getElementType();
1098     numElements = numElements * 2;
1099   }
1100   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1101   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1102          inRawData.size() <= outRawData.size());
1103   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1104                                   elementBitWidth, numElements);
1105 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // DenseFPElementsAttr
1109 //===----------------------------------------------------------------------===//
1110 
1111 template <typename Fn, typename Attr>
1112 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1113                                 Type newElementType,
1114                                 llvm::SmallVectorImpl<char> &data) {
1115   size_t bitWidth = getDenseElementBitWidth(newElementType);
1116   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1117 
1118   ShapedType newArrayType;
1119   if (inType.isa<RankedTensorType>())
1120     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1121   else if (inType.isa<UnrankedTensorType>())
1122     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1123   else if (inType.isa<VectorType>())
1124     newArrayType = VectorType::get(inType.getShape(), newElementType);
1125   else
1126     assert(newArrayType && "Unhandled tensor type");
1127 
1128   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1129   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1130 
1131   // Functor used to process a single element value of the attribute.
1132   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1133     auto newInt = mapping(value);
1134     assert(newInt.getBitWidth() == bitWidth);
1135     writeBits(data.data(), index * storageBitWidth, newInt);
1136   };
1137 
1138   // Check for the splat case.
1139   if (attr.isSplat()) {
1140     processElt(*attr.begin(), /*index=*/0);
1141     return newArrayType;
1142   }
1143 
1144   // Otherwise, process all of the element values.
1145   uint64_t elementIdx = 0;
1146   for (auto value : attr)
1147     processElt(value, elementIdx++);
1148   return newArrayType;
1149 }
1150 
1151 DenseElementsAttr DenseFPElementsAttr::mapValues(
1152     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1153   llvm::SmallVector<char, 8> elementData;
1154   auto newArrayType =
1155       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1156 
1157   return getRaw(newArrayType, elementData, isSplat());
1158 }
1159 
1160 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1161 bool DenseFPElementsAttr::classof(Attribute attr) {
1162   return attr.isa<DenseElementsAttr>() &&
1163          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // DenseIntElementsAttr
1168 //===----------------------------------------------------------------------===//
1169 
1170 DenseElementsAttr DenseIntElementsAttr::mapValues(
1171     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1172   llvm::SmallVector<char, 8> elementData;
1173   auto newArrayType =
1174       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1175 
1176   return getRaw(newArrayType, elementData, isSplat());
1177 }
1178 
1179 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1180 bool DenseIntElementsAttr::classof(Attribute attr) {
1181   return attr.isa<DenseElementsAttr>() &&
1182          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1183 }
1184 
1185 //===----------------------------------------------------------------------===//
1186 // OpaqueElementsAttr
1187 //===----------------------------------------------------------------------===//
1188 
1189 /// Return the value at the given index. If index does not refer to a valid
1190 /// element, then a null attribute is returned.
1191 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1192   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1193   return Attribute();
1194 }
1195 
1196 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1197   Dialect *dialect = getDialect().getDialect();
1198   if (!dialect)
1199     return true;
1200   auto *interface =
1201       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1202   if (!interface)
1203     return true;
1204   return failed(interface->decode(*this, result));
1205 }
1206 
1207 LogicalResult
1208 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1209                            Identifier dialect, StringRef value,
1210                            ShapedType type) {
1211   if (!Dialect::isValidNamespace(dialect.strref()))
1212     return emitError() << "invalid dialect namespace '" << dialect << "'";
1213   return success();
1214 }
1215 
1216 //===----------------------------------------------------------------------===//
1217 // SparseElementsAttr
1218 //===----------------------------------------------------------------------===//
1219 
/// Return the value of the element at the given index. 'index' must be a
/// valid multi-dimensional index (asserted). If 'index' matches one of the
/// stored sparse indices, the corresponding stored value is returned;
/// otherwise the zero attribute from getZeroAttr() is returned.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // If the index is also not a splat of the index value, we know that the
    // value is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // Each key is an ArrayRef viewing 'rank' consecutive index values starting
  // at offset i * rank of the flattened index data; the mapped value is the
  // position i of that sparse entry.
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    // try_emplace does not overwrite, so if the same index is stored more than
    // once the first occurrence wins.
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
1260 
1261 /// Get a zero APFloat for the given sparse attribute.
1262 APFloat SparseElementsAttr::getZeroAPFloat() const {
1263   auto eltType = getType().getElementType().cast<FloatType>();
1264   return APFloat(eltType.getFloatSemantics());
1265 }
1266 
1267 /// Get a zero APInt for the given sparse attribute.
1268 APInt SparseElementsAttr::getZeroAPInt() const {
1269   auto eltType = getType().getElementType().cast<IntegerType>();
1270   return APInt::getNullValue(eltType.getWidth());
1271 }
1272 
1273 /// Get a zero attribute for the given attribute type.
1274 Attribute SparseElementsAttr::getZeroAttr() const {
1275   auto eltType = getType().getElementType();
1276 
1277   // Handle floating point elements.
1278   if (eltType.isa<FloatType>())
1279     return FloatAttr::get(eltType, 0);
1280 
1281   // Otherwise, this is an integer.
1282   // TODO: Handle StringAttr here.
1283   return IntegerAttr::get(eltType, 0);
1284 }
1285 
1286 /// Flatten, and return, all of the sparse indices in this attribute in
1287 /// row-major order.
1288 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1289   std::vector<ptrdiff_t> flatSparseIndices;
1290 
1291   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1292   // as a 1-D index array.
1293   auto sparseIndices = getIndices();
1294   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1295   if (sparseIndices.isSplat()) {
1296     SmallVector<uint64_t, 8> indices(getType().getRank(),
1297                                      *sparseIndexValues.begin());
1298     flatSparseIndices.push_back(getFlattenedIndex(indices));
1299     return flatSparseIndices;
1300   }
1301 
1302   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1303   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1304   size_t rank = getType().getRank();
1305   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1306     flatSparseIndices.push_back(getFlattenedIndex(
1307         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1308   return flatSparseIndices;
1309 }
1310