1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
19 #include "llvm/ADT/APSInt.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
37 void BuiltinDialect::registerAttributes() {
38   addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
39                 DenseStringElementsAttr, DictionaryAttr, FloatAttr,
40                 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
41                 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
42                 UnitAttr>();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
57     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
58   std::vector<Attribute> vector = getValue().vec();
59   for (auto &it : replacements) {
60     vector[it.first] = it.second;
61   }
62   return get(getContext(), vector);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // DictionaryAttr
67 //===----------------------------------------------------------------------===//
68 
69 /// Helper function that does either an in place sort or sorts from source array
70 /// into destination. If inPlace then storage is both the source and the
71 /// destination, else value is the source and storage destination. Returns
72 /// whether source was sorted.
73 template <bool inPlace>
74 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
75                                SmallVectorImpl<NamedAttribute> &storage) {
76   // Specialize for the common case.
77   switch (value.size()) {
78   case 0:
79     // Zero already sorted.
80     if (!inPlace)
81       storage.clear();
82     break;
83   case 1:
84     // One already sorted but may need to be copied.
85     if (!inPlace)
86       storage.assign({value[0]});
87     break;
88   case 2: {
89     bool isSorted = value[0] < value[1];
90     if (inPlace) {
91       if (!isSorted)
92         std::swap(storage[0], storage[1]);
93     } else if (isSorted) {
94       storage.assign({value[0], value[1]});
95     } else {
96       storage.assign({value[1], value[0]});
97     }
98     return !isSorted;
99   }
100   default:
101     if (!inPlace)
102       storage.assign(value.begin(), value.end());
103     // Check to see they are sorted already.
104     bool isSorted = llvm::is_sorted(value);
105     // If not, do a general sort.
106     if (!isSorted)
107       llvm::array_pod_sort(storage.begin(), storage.end());
108     return !isSorted;
109   }
110   return false;
111 }
112 
113 /// Returns an entry with a duplicate name from the given sorted array of named
114 /// attributes. Returns llvm::None if all elements have unique names.
115 static Optional<NamedAttribute>
116 findDuplicateElement(ArrayRef<NamedAttribute> value) {
117   const Optional<NamedAttribute> none{llvm::None};
118   if (value.size() < 2)
119     return none;
120 
121   if (value.size() == 2)
122     return value[0].first == value[1].first ? value[0] : none;
123 
124   auto it = std::adjacent_find(
125       value.begin(), value.end(),
126       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
127   return it != value.end() ? *it : none;
128 }
129 
130 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
131                           SmallVectorImpl<NamedAttribute> &storage) {
132   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
133   assert(!findDuplicateElement(storage) &&
134          "DictionaryAttr element names must be unique");
135   return isSorted;
136 }
137 
138 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
139   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
140   assert(!findDuplicateElement(array) &&
141          "DictionaryAttr element names must be unique");
142   return isSorted;
143 }
144 
145 Optional<NamedAttribute>
146 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
147                               bool isSorted) {
148   if (!isSorted)
149     dictionaryAttrSort</*inPlace=*/true>(array, array);
150   return findDuplicateElement(array);
151 }
152 
153 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
154                                    ArrayRef<NamedAttribute> value) {
155   if (value.empty())
156     return DictionaryAttr::getEmpty(context);
157   assert(llvm::all_of(value,
158                       [](const NamedAttribute &attr) { return attr.second; }) &&
159          "value cannot have null entries");
160 
161   // We need to sort the element list to canonicalize it.
162   SmallVector<NamedAttribute, 8> storage;
163   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
164     value = storage;
165   assert(!findDuplicateElement(value) &&
166          "DictionaryAttr element names must be unique");
167   return Base::get(context, value);
168 }
169 /// Construct a dictionary with an array of values that is known to already be
170 /// sorted by name and uniqued.
171 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
172                                              ArrayRef<NamedAttribute> value) {
173   if (value.empty())
174     return DictionaryAttr::getEmpty(context);
175   // Ensure that the attribute elements are unique and sorted.
176   assert(llvm::is_sorted(value,
177                          [](NamedAttribute l, NamedAttribute r) {
178                            return l.first.strref() < r.first.strref();
179                          }) &&
180          "expected attribute values to be sorted");
181   assert(!findDuplicateElement(value) &&
182          "DictionaryAttr element names must be unique");
183   return Base::get(context, value);
184 }
185 
186 /// Return the specified attribute if present, null otherwise.
187 Attribute DictionaryAttr::get(StringRef name) const {
188   auto it = impl::findAttrSorted(begin(), end(), name);
189   return it.second ? it.first->second : Attribute();
190 }
191 Attribute DictionaryAttr::get(Identifier name) const {
192   auto it = impl::findAttrSorted(begin(), end(), name);
193   return it.second ? it.first->second : Attribute();
194 }
195 
196 /// Return the specified named attribute if present, None otherwise.
197 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
198   auto it = impl::findAttrSorted(begin(), end(), name);
199   return it.second ? *it.first : Optional<NamedAttribute>();
200 }
201 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
202   auto it = impl::findAttrSorted(begin(), end(), name);
203   return it.second ? *it.first : Optional<NamedAttribute>();
204 }
205 
206 /// Return whether the specified attribute is present.
207 bool DictionaryAttr::contains(StringRef name) const {
208   return impl::findAttrSorted(begin(), end(), name).second;
209 }
210 bool DictionaryAttr::contains(Identifier name) const {
211   return impl::findAttrSorted(begin(), end(), name).second;
212 }
213 
214 DictionaryAttr::iterator DictionaryAttr::begin() const {
215   return getValue().begin();
216 }
217 DictionaryAttr::iterator DictionaryAttr::end() const {
218   return getValue().end();
219 }
220 size_t DictionaryAttr::size() const { return getValue().size(); }
221 
222 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
223   return Base::get(context, ArrayRef<NamedAttribute>());
224 }
225 
226 void DictionaryAttr::walkImmediateSubElements(
227     function_ref<void(Attribute)> walkAttrsFn,
228     function_ref<void(Type)> walkTypesFn) const {
229   for (Attribute attr : llvm::make_second_range(getValue()))
230     walkAttrsFn(attr);
231 }
232 
233 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
234     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
235   std::vector<NamedAttribute> vec = getValue().vec();
236   for (auto &it : replacements) {
237     vec[it.first].second = it.second;
238   }
239   // The above only modifies the mapped value, but not the key, and therefore
240   // not the order of the elements. It remains sorted
241   return getWithSorted(getContext(), vec);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // StringAttr
246 //===----------------------------------------------------------------------===//
247 
248 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
249   return Base::get(context, "", NoneType::get(context));
250 }
251 
252 /// Twine support for StringAttr.
253 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
254   // Fast-path empty twine.
255   if (twine.isTriviallyEmpty())
256     return get(context);
257   SmallVector<char, 32> tempStr;
258   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
259 }
260 
261 /// Twine support for StringAttr.
262 StringAttr StringAttr::get(const Twine &twine, Type type) {
263   SmallVector<char, 32> tempStr;
264   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // FloatAttr
269 //===----------------------------------------------------------------------===//
270 
271 double FloatAttr::getValueAsDouble() const {
272   return getValueAsDouble(getValue());
273 }
274 double FloatAttr::getValueAsDouble(APFloat value) {
275   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
276     bool losesInfo = false;
277     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
278                   &losesInfo);
279   }
280   return value.convertToDouble();
281 }
282 
283 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
284                                 Type type, APFloat value) {
285   // Verify that the type is correct.
286   if (!type.isa<FloatType>())
287     return emitError() << "expected floating point type";
288 
289   // Verify that the type semantics match that of the value.
290   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
291     return emitError()
292            << "FloatAttr type doesn't match the type implied by its value";
293   }
294   return success();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // SymbolRefAttr
299 //===----------------------------------------------------------------------===//
300 
301 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
302                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
303   return get(StringAttr::get(ctx, value), nestedRefs);
304 }
305 
306 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
307   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
308 }
309 
310 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
311   return get(value, {}).cast<FlatSymbolRefAttr>();
312 }
313 
314 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
315   auto symName =
316       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
317   assert(symName && "value does not have a valid symbol name");
318   return SymbolRefAttr::get(symName);
319 }
320 
321 StringAttr SymbolRefAttr::getLeafReference() const {
322   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
323   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // IntegerAttr
328 //===----------------------------------------------------------------------===//
329 
330 int64_t IntegerAttr::getInt() const {
331   assert((getType().isIndex() || getType().isSignlessInteger()) &&
332          "must be signless integer");
333   return getValue().getSExtValue();
334 }
335 
336 int64_t IntegerAttr::getSInt() const {
337   assert(getType().isSignedInteger() && "must be signed integer");
338   return getValue().getSExtValue();
339 }
340 
341 uint64_t IntegerAttr::getUInt() const {
342   assert(getType().isUnsignedInteger() && "must be unsigned integer");
343   return getValue().getZExtValue();
344 }
345 
346 /// Return the value as an APSInt which carries the signed from the type of
347 /// the attribute.  This traps on signless integers types!
348 APSInt IntegerAttr::getAPSInt() const {
349   assert(!getType().isSignlessInteger() &&
350          "Signless integers don't carry a sign for APSInt");
351   return APSInt(getValue(), getType().isUnsignedInteger());
352 }
353 
354 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
355                                   Type type, APInt value) {
356   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
357     if (integerType.getWidth() != value.getBitWidth())
358       return emitError() << "integer type bit width (" << integerType.getWidth()
359                          << ") doesn't match value bit width ("
360                          << value.getBitWidth() << ")";
361     return success();
362   }
363   if (type.isa<IndexType>())
364     return success();
365   return emitError() << "expected integer or index type";
366 }
367 
368 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
369   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
370   return attr.cast<BoolAttr>();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // BoolAttr
375 
376 bool BoolAttr::getValue() const {
377   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
378   return storage->value.getBoolValue();
379 }
380 
381 bool BoolAttr::classof(Attribute attr) {
382   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
383   return intAttr && intAttr.getType().isSignlessInteger(1);
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // OpaqueAttr
388 //===----------------------------------------------------------------------===//
389 
390 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
391                                  Identifier dialect, StringRef attrData,
392                                  Type type) {
393   if (!Dialect::isValidNamespace(dialect.strref()))
394     return emitError() << "invalid dialect namespace '" << dialect << "'";
395 
396   // Check that the dialect is actually registered.
397   MLIRContext *context = dialect.getContext();
398   if (!context->allowsUnregisteredDialects() &&
399       !context->getLoadedDialect(dialect.strref())) {
400     return emitError()
401            << "#" << dialect << "<\"" << attrData << "\"> : " << type
402            << " attribute created with unregistered dialect. If this is "
403               "intended, please call allowUnregisteredDialects() on the "
404               "MLIRContext, or use -allow-unregistered-dialect with "
405               "the MLIR opt tool used";
406   }
407 
408   return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // DenseElementsAttr Utilities
413 //===----------------------------------------------------------------------===//
414 
415 /// Get the bitwidth of a dense element type within the buffer.
416 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
417 static size_t getDenseElementStorageWidth(size_t origWidth) {
418   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
419 }
420 static size_t getDenseElementStorageWidth(Type elementType) {
421   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
422 }
423 
/// Set (`value` true) or clear (`value` false) the bit at index `bitPos` of
/// `rawData`, where bit 0 is the least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}

/// Return the bit at index `bitPos` of `rawData` (same bit numbering as
/// setBit).
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
436 
/// Copy the low `numBytes` data bytes of `value` (an APInt) into the char array
/// `result`. Only valid on big-endian hosts (asserted below).
///
/// All words of `value` except the last are copied verbatim. In the last word,
/// the remaining data bytes sit at the low-order end, which on a big-endian
/// host is the *end* of that word's bytes (see the `|------ij|` diagrams). The
/// last word is therefore byte-reversed into a temporary, and just the needed
/// bytes are reversed back into `result`.
///
/// `DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine` is defined
/// elsewhere. From its use here it is assumed to byte-reverse each element of
/// the given bit width while copying from the first buffer to the second —
/// NOTE(review): confirm against its definition.
///
/// Precondition (asserted): `value` has enough words to hold `numBytes` bytes.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a *byte* count (the bytes of
  // every word before the last one).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`. Only the (numBytes - lastWordPos) leading bytes
  // of `valueLE` hold data; the rest of the last word is not written.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
466 
/// Copy `numBytes` data bytes from the char array `inArray` into the word
/// storage of `result` (an APInt). This is the inverse of
/// copyAPIntToArrayForBEmachine. Only valid on big-endian hosts (asserted
/// below).
///
/// All bytes before the last word are copied verbatim. The trailing partial
/// word is byte-reversed into a temporary and then reversed back as a full
/// word, so its data bytes land at the low-order end of `result`'s last word.
///
/// Writes go directly into `result`'s storage through a const_cast of
/// getRawData(). This assumes that storage is owned and mutable by `result`.
/// NOTE(review): relies on APInt's internal layout; confirm when upgrading
/// LLVM.
///
/// Precondition (asserted): `result` has enough words to hold `numBytes`
/// bytes.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: despite its name, `numFilledWords` is a *byte* count.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`). `inArrayLE` starts zero-filled, so
  // the bytes past the data stay zero.
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
502 
503 /// Writes value to the bit position `bitPos` in array `rawData`.
504 static void writeBits(char *rawData, size_t bitPos, APInt value) {
505   size_t bitWidth = value.getBitWidth();
506 
507   // If the bitwidth is 1 we just toggle the specific bit.
508   if (bitWidth == 1)
509     return setBit(rawData, bitPos, value.isOneValue());
510 
511   // Otherwise, the bit position is guaranteed to be byte aligned.
512   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
513   if (llvm::support::endian::system_endianness() ==
514       llvm::support::endianness::big) {
515     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
516     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
517     // work correctly in BE format.
518     // ex. `value` (2 words including 10 bytes)
519     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
520     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
521                                  rawData + (bitPos / CHAR_BIT));
522   } else {
523     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
524                 llvm::divideCeil(bitWidth, CHAR_BIT),
525                 rawData + (bitPos / CHAR_BIT));
526   }
527 }
528 
529 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
530 /// `rawData`.
531 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
532   // Handle a boolean bit position.
533   if (bitWidth == 1)
534     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
535 
536   // Otherwise, the bit position must be 8-bit aligned.
537   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
538   APInt result(bitWidth, 0);
539   if (llvm::support::endian::system_endianness() ==
540       llvm::support::endianness::big) {
541     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
542     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
543     // work correctly in BE format.
544     // ex. `result` (2 words including 10 bytes)
545     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
546     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
547                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
548   } else {
549     std::copy_n(rawData + (bitPos / CHAR_BIT),
550                 llvm::divideCeil(bitWidth, CHAR_BIT),
551                 const_cast<char *>(
552                     reinterpret_cast<const char *>(result.getRawData())));
553   }
554   return result;
555 }
556 
557 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
558 /// the same element count as 'type'.
559 template <typename Values>
560 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
561   return (values.size() == 1) ||
562          (type.getNumElements() == static_cast<int64_t>(values.size()));
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // DenseElementsAttr Iterators
567 //===----------------------------------------------------------------------===//
568 
569 //===----------------------------------------------------------------------===//
570 // AttributeElementIterator
571 
572 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
573     DenseElementsAttr attr, size_t index)
574     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
575                                       Attribute, Attribute, Attribute>(
576           attr.getAsOpaquePointer(), index) {}
577 
578 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
579   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
580   Type eltTy = owner.getElementType();
581   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
582     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
583   if (eltTy.isa<IndexType>())
584     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
585   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
586     IntElementIterator intIt(owner, index);
587     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
588     return FloatAttr::get(eltTy, *floatIt);
589   }
590   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
591     auto complexEltTy = complexTy.getElementType();
592     ComplexIntElementIterator complexIntIt(owner, index);
593     if (complexEltTy.isa<IntegerType>()) {
594       auto value = *complexIntIt;
595       auto real = IntegerAttr::get(complexEltTy, value.real());
596       auto imag = IntegerAttr::get(complexEltTy, value.imag());
597       return ArrayAttr::get(complexTy.getContext(),
598                             ArrayRef<Attribute>{real, imag});
599     }
600 
601     ComplexFloatElementIterator complexFloatIt(
602         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
603     auto value = *complexFloatIt;
604     auto real = FloatAttr::get(complexEltTy, value.real());
605     auto imag = FloatAttr::get(complexEltTy, value.imag());
606     return ArrayAttr::get(complexTy.getContext(),
607                           ArrayRef<Attribute>{real, imag});
608   }
609   if (owner.isa<DenseStringElementsAttr>()) {
610     ArrayRef<StringRef> vals = owner.getRawStringData();
611     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
612   }
613   llvm_unreachable("unexpected element type");
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // BoolElementIterator
618 
619 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
620     DenseElementsAttr attr, size_t dataIndex)
621     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
622           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
623 
624 bool DenseElementsAttr::BoolElementIterator::operator*() const {
625   return getBit(getData(), getDataIndex());
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // IntElementIterator
630 
631 DenseElementsAttr::IntElementIterator::IntElementIterator(
632     DenseElementsAttr attr, size_t dataIndex)
633     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
634           attr.getRawData().data(), attr.isSplat(), dataIndex),
635       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
636 
637 APInt DenseElementsAttr::IntElementIterator::operator*() const {
638   return readBits(getData(),
639                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
640                   bitWidth);
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // ComplexIntElementIterator
645 
646 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
647     DenseElementsAttr attr, size_t dataIndex)
648     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
649                                       std::complex<APInt>, std::complex<APInt>,
650                                       std::complex<APInt>>(
651           attr.getRawData().data(), attr.isSplat(), dataIndex) {
652   auto complexType = attr.getElementType().cast<ComplexType>();
653   bitWidth = getDenseElementBitWidth(complexType.getElementType());
654 }
655 
656 std::complex<APInt>
657 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
658   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
659   size_t offset = getDataIndex() * storageWidth * 2;
660   return {readBits(getData(), offset, bitWidth),
661           readBits(getData(), offset + storageWidth, bitWidth)};
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // FloatElementIterator
666 
667 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
668     const llvm::fltSemantics &smt, IntElementIterator it)
669     : llvm::mapped_iterator<IntElementIterator,
670                             std::function<APFloat(const APInt &)>>(
671           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
672 
673 //===----------------------------------------------------------------------===//
674 // ComplexFloatElementIterator
675 
676 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
677     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
678     : llvm::mapped_iterator<
679           ComplexIntElementIterator,
680           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
681           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
682             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
683           }) {}
684 
685 //===----------------------------------------------------------------------===//
686 // DenseElementsAttr
687 //===----------------------------------------------------------------------===//
688 
689 /// Method for support type inquiry through isa, cast and dyn_cast.
690 bool DenseElementsAttr::classof(Attribute attr) {
691   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
692 }
693 
694 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
695                                          ArrayRef<Attribute> values) {
696   assert(hasSameElementsOrSplat(type, values));
697 
698   // If the element type is not based on int/float/index, assume it is a string
699   // type.
700   auto eltType = type.getElementType();
701   if (!type.getElementType().isIntOrIndexOrFloat()) {
702     SmallVector<StringRef, 8> stringValues;
703     stringValues.reserve(values.size());
704     for (Attribute attr : values) {
705       assert(attr.isa<StringAttr>() &&
706              "expected string value for non integer/index/float element");
707       stringValues.push_back(attr.cast<StringAttr>().getValue());
708     }
709     return get(type, stringValues);
710   }
711 
712   // Otherwise, get the raw storage width to use for the allocation.
713   size_t bitWidth = getDenseElementBitWidth(eltType);
714   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
715 
716   // Compress the attribute values into a character buffer.
717   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
718                             values.size());
719   APInt intVal;
720   for (unsigned i = 0, e = values.size(); i < e; ++i) {
721     assert(eltType == values[i].getType() &&
722            "expected attribute value to have element type");
723     if (eltType.isa<FloatType>())
724       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
725     else if (eltType.isa<IntegerType, IndexType>())
726       intVal = values[i].cast<IntegerAttr>().getValue();
727     else
728       llvm_unreachable("unexpected element type");
729 
730     assert(intVal.getBitWidth() == bitWidth &&
731            "expected value to have same bitwidth as element type");
732     writeBits(data.data(), i * storageBitWidth, intVal);
733   }
734   return DenseIntOrFPElementsAttr::getRaw(type, data,
735                                           /*isSplat=*/(values.size() == 1));
736 }
737 
738 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
739                                          ArrayRef<bool> values) {
740   assert(hasSameElementsOrSplat(type, values));
741   assert(type.getElementType().isInteger(1));
742 
743   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
744   for (int i = 0, e = values.size(); i != e; ++i)
745     setBit(buff.data(), i, values[i]);
746   return DenseIntOrFPElementsAttr::getRaw(type, buff,
747                                           /*isSplat=*/(values.size() == 1));
748 }
749 
750 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
751                                          ArrayRef<StringRef> values) {
752   assert(!type.getElementType().isIntOrFloat());
753   return DenseStringElementsAttr::get(type, values);
754 }
755 
756 /// Constructs a dense integer elements attribute from an array of APInt
757 /// values. Each APInt value is expected to have the same bitwidth as the
758 /// element type of 'type'.
759 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
760                                          ArrayRef<APInt> values) {
761   assert(type.getElementType().isIntOrIndex());
762   assert(hasSameElementsOrSplat(type, values));
763   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
764   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
765                                           /*isSplat=*/(values.size() == 1));
766 }
767 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
768                                          ArrayRef<std::complex<APInt>> values) {
769   ComplexType complex = type.getElementType().cast<ComplexType>();
770   assert(complex.getElementType().isa<IntegerType>());
771   assert(hasSameElementsOrSplat(type, values));
772   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
773   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
774                           values.size() * 2);
775   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
776                                           /*isSplat=*/(values.size() == 1));
777 }
778 
779 // Constructs a dense float elements attribute from an array of APFloat
780 // values. Each APFloat value is expected to have the same bitwidth as the
781 // element type of 'type'.
782 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
783                                          ArrayRef<APFloat> values) {
784   assert(type.getElementType().isa<FloatType>());
785   assert(hasSameElementsOrSplat(type, values));
786   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
787   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
788                                           /*isSplat=*/(values.size() == 1));
789 }
790 DenseElementsAttr
791 DenseElementsAttr::get(ShapedType type,
792                        ArrayRef<std::complex<APFloat>> values) {
793   ComplexType complex = type.getElementType().cast<ComplexType>();
794   assert(complex.getElementType().isa<FloatType>());
795   assert(hasSameElementsOrSplat(type, values));
796   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
797                            values.size() * 2);
798   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
799   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
800                                           /*isSplat=*/(values.size() == 1));
801 }
802 
803 /// Construct a dense elements attribute from a raw buffer representing the
804 /// data for this attribute. Users should generally not use this methods as
805 /// the expected buffer format may not be a form the user expects.
806 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
807                                                       ArrayRef<char> rawBuffer,
808                                                       bool isSplatBuffer) {
809   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
810 }
811 
812 /// Returns true if the given buffer is a valid raw buffer for the given type.
813 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
814                                          ArrayRef<char> rawBuffer,
815                                          bool &detectedSplat) {
816   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
817   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
818 
819   // Storage width of 1 is special as it is packed by the bit.
820   if (storageWidth == 1) {
821     // Check for a splat, or a buffer equal to the number of elements which
822     // consists of either all 0's or all 1's.
823     detectedSplat = false;
824     if (rawBuffer.size() == 1) {
825       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
826       if (rawByte == 0 || rawByte == 0xff) {
827         detectedSplat = true;
828         return true;
829       }
830     }
831     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
832   }
833   // All other types are 8-bit aligned.
834   if ((detectedSplat = rawBufferWidth == storageWidth))
835     return true;
836   return rawBufferWidth == (storageWidth * type.getNumElements());
837 }
838 
839 /// Check the information for a C++ data type, check if this type is valid for
840 /// the current attribute. This method is used to verify specific type
841 /// invariants that the templatized 'getValues' method cannot.
842 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
843                               bool isSigned) {
844   // Make sure that the data element size is the same as the type element width.
845   if (getDenseElementBitWidth(type) !=
846       static_cast<size_t>(dataEltSize * CHAR_BIT))
847     return false;
848 
849   // Check that the element type is either float or integer or index.
850   if (!isInt)
851     return type.isa<FloatType>();
852   if (type.isIndex())
853     return true;
854 
855   auto intType = type.dyn_cast<IntegerType>();
856   if (!intType)
857     return false;
858 
859   // Make sure signedness semantics is consistent.
860   if (intType.isSignless())
861     return true;
862   return intType.isSigned() ? isSigned : !isSigned;
863 }
864 
865 /// Defaults down the subclass implementation.
866 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
867                                                    ArrayRef<char> data,
868                                                    int64_t dataEltSize,
869                                                    bool isInt, bool isSigned) {
870   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
871                                                  isSigned);
872 }
873 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
874                                                       ArrayRef<char> data,
875                                                       int64_t dataEltSize,
876                                                       bool isInt,
877                                                       bool isSigned) {
878   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
879                                                     isInt, isSigned);
880 }
881 
882 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
883                                           bool isSigned) const {
884   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
885 }
886 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
887                                        bool isSigned) const {
888   return ::isValidIntOrFloat(
889       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
890       isInt, isSigned);
891 }
892 
893 /// Returns true if this attribute corresponds to a splat, i.e. if all element
894 /// values are the same.
895 bool DenseElementsAttr::isSplat() const {
896   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
897 }
898 
899 /// Return if the given complex type has an integer element type.
900 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
901   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
902 }
903 
904 auto DenseElementsAttr::getComplexIntValues() const
905     -> iterator_range_impl<ComplexIntElementIterator> {
906   assert(isComplexOfIntType(getElementType()) &&
907          "expected complex integral type");
908   return {getType(), ComplexIntElementIterator(*this, 0),
909           ComplexIntElementIterator(*this, getNumElements())};
910 }
911 auto DenseElementsAttr::complex_value_begin() const
912     -> ComplexIntElementIterator {
913   assert(isComplexOfIntType(getElementType()) &&
914          "expected complex integral type");
915   return ComplexIntElementIterator(*this, 0);
916 }
917 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
918   assert(isComplexOfIntType(getElementType()) &&
919          "expected complex integral type");
920   return ComplexIntElementIterator(*this, getNumElements());
921 }
922 
923 /// Return the held element values as a range of APFloat. The element type of
924 /// this attribute must be of float type.
925 auto DenseElementsAttr::getFloatValues() const
926     -> iterator_range_impl<FloatElementIterator> {
927   auto elementType = getElementType().cast<FloatType>();
928   const auto &elementSemantics = elementType.getFloatSemantics();
929   return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
930           FloatElementIterator(elementSemantics, raw_int_end())};
931 }
932 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
933   auto elementType = getElementType().cast<FloatType>();
934   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
935 }
936 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
937   auto elementType = getElementType().cast<FloatType>();
938   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
939 }
940 
941 auto DenseElementsAttr::getComplexFloatValues() const
942     -> iterator_range_impl<ComplexFloatElementIterator> {
943   Type eltTy = getElementType().cast<ComplexType>().getElementType();
944   assert(eltTy.isa<FloatType>() && "expected complex float type");
945   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
946   return {getType(),
947           {semantics, {*this, 0}},
948           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
949 }
950 auto DenseElementsAttr::complex_float_value_begin() const
951     -> ComplexFloatElementIterator {
952   Type eltTy = getElementType().cast<ComplexType>().getElementType();
953   assert(eltTy.isa<FloatType>() && "expected complex float type");
954   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
955 }
956 auto DenseElementsAttr::complex_float_value_end() const
957     -> ComplexFloatElementIterator {
958   Type eltTy = getElementType().cast<ComplexType>().getElementType();
959   assert(eltTy.isa<FloatType>() && "expected complex float type");
960   return {eltTy.cast<FloatType>().getFloatSemantics(),
961           {*this, static_cast<size_t>(getNumElements())}};
962 }
963 
964 /// Return the raw storage data held by this attribute.
965 ArrayRef<char> DenseElementsAttr::getRawData() const {
966   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
967 }
968 
969 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
970   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
971 }
972 
973 /// Return a new DenseElementsAttr that has the same data as the current
974 /// attribute, but has been reshaped to 'newType'. The new type must have the
975 /// same total number of elements as well as element type.
976 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
977   ShapedType curType = getType();
978   if (curType == newType)
979     return *this;
980 
981   assert(newType.getElementType() == curType.getElementType() &&
982          "expected the same element type");
983   assert(newType.getNumElements() == curType.getNumElements() &&
984          "expected the same number of elements");
985   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
986 }
987 
988 /// Return a new DenseElementsAttr that has the same data as the current
989 /// attribute, but has bitcast elements such that it is now 'newType'. The new
990 /// type must have the same shape and element types of the same bitwidth as the
991 /// current type.
992 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
993   ShapedType curType = getType();
994   Type curElType = curType.getElementType();
995   if (curElType == newElType)
996     return *this;
997 
998   assert(getDenseElementBitWidth(newElType) ==
999              getDenseElementBitWidth(curElType) &&
1000          "expected element types with the same bitwidth");
1001   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1002                                           getRawData(), isSplat());
1003 }
1004 
1005 DenseElementsAttr
1006 DenseElementsAttr::mapValues(Type newElementType,
1007                              function_ref<APInt(const APInt &)> mapping) const {
1008   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1009 }
1010 
1011 DenseElementsAttr DenseElementsAttr::mapValues(
1012     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1013   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1014 }
1015 
1016 ShapedType DenseElementsAttr::getType() const {
1017   return Attribute::getType().cast<ShapedType>();
1018 }
1019 
1020 Type DenseElementsAttr::getElementType() const {
1021   return getType().getElementType();
1022 }
1023 
1024 int64_t DenseElementsAttr::getNumElements() const {
1025   return getType().getNumElements();
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // DenseIntOrFPElementsAttr
1030 //===----------------------------------------------------------------------===//
1031 
1032 /// Utility method to write a range of APInt values to a buffer.
1033 template <typename APRangeT>
1034 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1035                                 APRangeT &&values) {
1036   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1037   size_t offset = 0;
1038   for (auto it = values.begin(), e = values.end(); it != e;
1039        ++it, offset += storageWidth) {
1040     assert((*it).getBitWidth() <= storageWidth);
1041     writeBits(data.data(), offset, *it);
1042   }
1043 }
1044 
1045 /// Constructs a dense elements attribute from an array of raw APFloat values.
1046 /// Each APFloat value is expected to have the same bitwidth as the element
1047 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1048 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1049                                                    size_t storageWidth,
1050                                                    ArrayRef<APFloat> values,
1051                                                    bool isSplat) {
1052   std::vector<char> data;
1053   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1054   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1055   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1056 }
1057 
1058 /// Constructs a dense elements attribute from an array of raw APInt values.
1059 /// Each APInt value is expected to have the same bitwidth as the element type
1060 /// of 'type'.
1061 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1062                                                    size_t storageWidth,
1063                                                    ArrayRef<APInt> values,
1064                                                    bool isSplat) {
1065   std::vector<char> data;
1066   writeAPIntsToBuffer(storageWidth, data, values);
1067   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1068 }
1069 
1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1071                                                    ArrayRef<char> data,
1072                                                    bool isSplat) {
1073   assert((type.isa<RankedTensorType, VectorType>()) &&
1074          "type must be ranked tensor or vector");
1075   assert(type.hasStaticShape() && "type must have static shape");
1076   return Base::get(type.getContext(), type, data, isSplat);
1077 }
1078 
1079 /// Overload of the raw 'get' method that asserts that the given type is of
1080 /// complex type. This method is used to verify type invariants that the
1081 /// templatized 'get' method cannot.
1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1083                                                           ArrayRef<char> data,
1084                                                           int64_t dataEltSize,
1085                                                           bool isInt,
1086                                                           bool isSigned) {
1087   assert(::isValidIntOrFloat(
1088       type.getElementType().cast<ComplexType>().getElementType(),
1089       dataEltSize / 2, isInt, isSigned));
1090 
1091   int64_t numElements = data.size() / dataEltSize;
1092   assert(numElements == 1 || numElements == type.getNumElements());
1093   return getRaw(type, data, /*isSplat=*/numElements == 1);
1094 }
1095 
1096 /// Overload of the 'getRaw' method that asserts that the given type is of
1097 /// integer type. This method is used to verify type invariants that the
1098 /// templatized 'get' method cannot.
1099 DenseElementsAttr
1100 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1101                                            int64_t dataEltSize, bool isInt,
1102                                            bool isSigned) {
1103   assert(
1104       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1105 
1106   int64_t numElements = data.size() / dataEltSize;
1107   assert(numElements == 1 || numElements == type.getNumElements());
1108   return getRaw(type, data, /*isSplat=*/numElements == 1);
1109 }
1110 
1111 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1112     const char *inRawData, char *outRawData, size_t elementBitWidth,
1113     size_t numElements) {
1114   using llvm::support::ulittle16_t;
1115   using llvm::support::ulittle32_t;
1116   using llvm::support::ulittle64_t;
1117 
1118   assert(llvm::support::endian::system_endianness() == // NOLINT
1119          llvm::support::endianness::big);              // NOLINT
1120   // NOLINT to avoid warning message about replacing by static_assert()
1121 
1122   // Following std::copy_n always converts endianness on BE machine.
1123   switch (elementBitWidth) {
1124   case 16: {
1125     const ulittle16_t *inRawDataPos =
1126         reinterpret_cast<const ulittle16_t *>(inRawData);
1127     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1128     std::copy_n(inRawDataPos, numElements, outDataPos);
1129     break;
1130   }
1131   case 32: {
1132     const ulittle32_t *inRawDataPos =
1133         reinterpret_cast<const ulittle32_t *>(inRawData);
1134     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1135     std::copy_n(inRawDataPos, numElements, outDataPos);
1136     break;
1137   }
1138   case 64: {
1139     const ulittle64_t *inRawDataPos =
1140         reinterpret_cast<const ulittle64_t *>(inRawData);
1141     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1142     std::copy_n(inRawDataPos, numElements, outDataPos);
1143     break;
1144   }
1145   default: {
1146     size_t nBytes = elementBitWidth / CHAR_BIT;
1147     for (size_t i = 0; i < nBytes; i++)
1148       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1149     break;
1150   }
1151   }
1152 }
1153 
1154 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1155     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1156     ShapedType type) {
1157   size_t numElements = type.getNumElements();
1158   Type elementType = type.getElementType();
1159   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1160     elementType = complexTy.getElementType();
1161     numElements = numElements * 2;
1162   }
1163   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1164   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1165          inRawData.size() <= outRawData.size());
1166   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1167                                   elementBitWidth, numElements);
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // DenseFPElementsAttr
1172 //===----------------------------------------------------------------------===//
1173 
1174 template <typename Fn, typename Attr>
1175 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1176                                 Type newElementType,
1177                                 llvm::SmallVectorImpl<char> &data) {
1178   size_t bitWidth = getDenseElementBitWidth(newElementType);
1179   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1180 
1181   ShapedType newArrayType;
1182   if (inType.isa<RankedTensorType>())
1183     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1184   else if (inType.isa<UnrankedTensorType>())
1185     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1186   else if (inType.isa<VectorType>())
1187     newArrayType = VectorType::get(inType.getShape(), newElementType);
1188   else
1189     assert(newArrayType && "Unhandled tensor type");
1190 
1191   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1192   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1193 
1194   // Functor used to process a single element value of the attribute.
1195   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1196     auto newInt = mapping(value);
1197     assert(newInt.getBitWidth() == bitWidth);
1198     writeBits(data.data(), index * storageBitWidth, newInt);
1199   };
1200 
1201   // Check for the splat case.
1202   if (attr.isSplat()) {
1203     processElt(*attr.begin(), /*index=*/0);
1204     return newArrayType;
1205   }
1206 
1207   // Otherwise, process all of the element values.
1208   uint64_t elementIdx = 0;
1209   for (auto value : attr)
1210     processElt(value, elementIdx++);
1211   return newArrayType;
1212 }
1213 
1214 DenseElementsAttr DenseFPElementsAttr::mapValues(
1215     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1216   llvm::SmallVector<char, 8> elementData;
1217   auto newArrayType =
1218       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1219 
1220   return getRaw(newArrayType, elementData, isSplat());
1221 }
1222 
1223 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1224 bool DenseFPElementsAttr::classof(Attribute attr) {
1225   return attr.isa<DenseElementsAttr>() &&
1226          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1227 }
1228 
1229 //===----------------------------------------------------------------------===//
1230 // DenseIntElementsAttr
1231 //===----------------------------------------------------------------------===//
1232 
1233 DenseElementsAttr DenseIntElementsAttr::mapValues(
1234     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1235   llvm::SmallVector<char, 8> elementData;
1236   auto newArrayType =
1237       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1238 
1239   return getRaw(newArrayType, elementData, isSplat());
1240 }
1241 
1242 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1243 bool DenseIntElementsAttr::classof(Attribute attr) {
1244   return attr.isa<DenseElementsAttr>() &&
1245          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // OpaqueElementsAttr
1250 //===----------------------------------------------------------------------===//
1251 
1252 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1253   Dialect *dialect = getDialect().getDialect();
1254   if (!dialect)
1255     return true;
1256   auto *interface =
1257       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1258   if (!interface)
1259     return true;
1260   return failed(interface->decode(*this, result));
1261 }
1262 
1263 LogicalResult
1264 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1265                            Identifier dialect, StringRef value,
1266                            ShapedType type) {
1267   if (!Dialect::isValidNamespace(dialect.strref()))
1268     return emitError() << "invalid dialect namespace '" << dialect << "'";
1269   return success();
1270 }
1271 
1272 //===----------------------------------------------------------------------===//
1273 // SparseElementsAttr
1274 //===----------------------------------------------------------------------===//
1275 
1276 /// Get a zero APFloat for the given sparse attribute.
1277 APFloat SparseElementsAttr::getZeroAPFloat() const {
1278   auto eltType = getElementType().cast<FloatType>();
1279   return APFloat(eltType.getFloatSemantics());
1280 }
1281 
1282 /// Get a zero APInt for the given sparse attribute.
1283 APInt SparseElementsAttr::getZeroAPInt() const {
1284   auto eltType = getElementType().cast<IntegerType>();
1285   return APInt::getZero(eltType.getWidth());
1286 }
1287 
1288 /// Get a zero attribute for the given attribute type.
1289 Attribute SparseElementsAttr::getZeroAttr() const {
1290   auto eltType = getElementType();
1291 
1292   // Handle floating point elements.
1293   if (eltType.isa<FloatType>())
1294     return FloatAttr::get(eltType, 0);
1295 
1296   // Handle string type.
1297   if (getValues().isa<DenseStringElementsAttr>())
1298     return StringAttr::get("", eltType);
1299 
1300   // Otherwise, this is an integer.
1301   return IntegerAttr::get(eltType, 0);
1302 }
1303 
1304 /// Flatten, and return, all of the sparse indices in this attribute in
1305 /// row-major order.
1306 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1307   std::vector<ptrdiff_t> flatSparseIndices;
1308 
1309   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1310   // as a 1-D index array.
1311   auto sparseIndices = getIndices();
1312   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1313   if (sparseIndices.isSplat()) {
1314     SmallVector<uint64_t, 8> indices(getType().getRank(),
1315                                      *sparseIndexValues.begin());
1316     flatSparseIndices.push_back(getFlattenedIndex(indices));
1317     return flatSparseIndices;
1318   }
1319 
1320   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1321   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1322   size_t rank = getType().getRank();
1323   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1324     flatSparseIndices.push_back(getFlattenedIndex(
1325         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1326   return flatSparseIndices;
1327 }
1328 
1329 LogicalResult
1330 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1331                            ShapedType type, DenseIntElementsAttr sparseIndices,
1332                            DenseElementsAttr values) {
1333   ShapedType valuesType = values.getType();
1334   if (valuesType.getRank() != 1)
1335     return emitError() << "expected 1-d tensor for sparse element values";
1336 
1337   // Verify the indices and values shape.
1338   ShapedType indicesType = sparseIndices.getType();
1339   auto emitShapeError = [&]() {
1340     return emitError() << "expected shape ([" << type.getShape()
1341                        << "]); inferred shape of indices literal (["
1342                        << indicesType.getShape()
1343                        << "]); inferred shape of values literal (["
1344                        << valuesType.getShape() << "])";
1345   };
1346   // Verify indices shape.
1347   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1348   if (indicesRank == 2) {
1349     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1350       return emitShapeError();
1351   } else if (indicesRank != 1 || rank != 1) {
1352     return emitShapeError();
1353   }
1354   // Verify the values shape.
1355   int64_t numSparseIndices = indicesType.getDimSize(0);
1356   if (numSparseIndices != valuesType.getDimSize(0))
1357     return emitShapeError();
1358 
1359   // Verify that the sparse indices are within the value shape.
1360   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1361     return emitError()
1362            << "sparse index #" << indexNum
1363            << " is not contained within the value shape, with index=[" << index
1364            << "], and type=" << type;
1365   };
1366 
1367   // Handle the case where the index values are a splat.
1368   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1369   if (sparseIndices.isSplat()) {
1370     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1371     if (!ElementsAttr::isValidIndex(type, indices))
1372       return emitIndexError(0, indices);
1373     return success();
1374   }
1375 
1376   // Otherwise, reinterpret each index as an ArrayRef.
1377   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1378     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1379                              rank);
1380     if (!ElementsAttr::isValidIndex(type, index))
1381       return emitIndexError(i, index);
1382   }
1383 
1384   return success();
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // TypeAttr
1389 //===----------------------------------------------------------------------===//
1390 
1391 void TypeAttr::walkImmediateSubElements(
1392     function_ref<void(Attribute)> walkAttrsFn,
1393     function_ref<void(Type)> walkTypesFn) const {
1394   walkTypesFn(getValue());
1395 }
1396