1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
19 #include "llvm/ADT/APSInt.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
37 void BuiltinDialect::registerAttributes() {
38   addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
39                 DenseStringElementsAttr, DictionaryAttr, FloatAttr,
40                 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
41                 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
42                 UnitAttr>();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
57     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
58   std::vector<Attribute> vector = getValue().vec();
59   for (auto &it : replacements) {
60     vector[it.first] = it.second;
61   }
62   return get(getContext(), vector);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // DictionaryAttr
67 //===----------------------------------------------------------------------===//
68 
/// Helper that sorts named attributes either in place or from a source array
/// into a destination, ordering entries with `operator<` on NamedAttribute.
///
/// If `inPlace` is true, `value` and `storage` must refer to the same elements
/// (callers pass the same vector for both); the sort is applied to `storage`,
/// which is therefore also visible through `value`. If `inPlace` is false,
/// `value` is the source and `storage` is overwritten with a sorted copy of it.
/// In that case `storage` is filled even when `value` was already sorted.
///
/// Returns true if the input was NOT already sorted (a reordering was needed),
/// and false if it was already in order.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    if (!inPlace)
      storage.clear();
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      // `value` aliases `storage` here, so swapping `storage` reorders the
      // caller's array directly.
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already. The check reads `value`, which is
    // either the untouched source or (in place) the same elements as `storage`.
    bool isSorted = llvm::is_sorted(value);
    // If not, do a general sort.
    if (!isSorted)
      llvm::array_pod_sort(storage.begin(), storage.end());
    return !isSorted;
  }
  // Reached only for sizes 0 and 1, which never need reordering.
  return false;
}
112 
113 /// Returns an entry with a duplicate name from the given sorted array of named
114 /// attributes. Returns llvm::None if all elements have unique names.
115 static Optional<NamedAttribute>
116 findDuplicateElement(ArrayRef<NamedAttribute> value) {
117   const Optional<NamedAttribute> none{llvm::None};
118   if (value.size() < 2)
119     return none;
120 
121   if (value.size() == 2)
122     return value[0].first == value[1].first ? value[0] : none;
123 
124   auto it = std::adjacent_find(
125       value.begin(), value.end(),
126       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
127   return it != value.end() ? *it : none;
128 }
129 
130 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
131                           SmallVectorImpl<NamedAttribute> &storage) {
132   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
133   assert(!findDuplicateElement(storage) &&
134          "DictionaryAttr element names must be unique");
135   return isSorted;
136 }
137 
138 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
139   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
140   assert(!findDuplicateElement(array) &&
141          "DictionaryAttr element names must be unique");
142   return isSorted;
143 }
144 
145 Optional<NamedAttribute>
146 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
147                               bool isSorted) {
148   if (!isSorted)
149     dictionaryAttrSort</*inPlace=*/true>(array, array);
150   return findDuplicateElement(array);
151 }
152 
153 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
154                                    ArrayRef<NamedAttribute> value) {
155   if (value.empty())
156     return DictionaryAttr::getEmpty(context);
157   assert(llvm::all_of(value,
158                       [](const NamedAttribute &attr) { return attr.second; }) &&
159          "value cannot have null entries");
160 
161   // We need to sort the element list to canonicalize it.
162   SmallVector<NamedAttribute, 8> storage;
163   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
164     value = storage;
165   assert(!findDuplicateElement(value) &&
166          "DictionaryAttr element names must be unique");
167   return Base::get(context, value);
168 }
169 /// Construct a dictionary with an array of values that is known to already be
170 /// sorted by name and uniqued.
171 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
172                                              ArrayRef<NamedAttribute> value) {
173   if (value.empty())
174     return DictionaryAttr::getEmpty(context);
175   // Ensure that the attribute elements are unique and sorted.
176   assert(llvm::is_sorted(value,
177                          [](NamedAttribute l, NamedAttribute r) {
178                            return l.first.strref() < r.first.strref();
179                          }) &&
180          "expected attribute values to be sorted");
181   assert(!findDuplicateElement(value) &&
182          "DictionaryAttr element names must be unique");
183   return Base::get(context, value);
184 }
185 
186 /// Return the specified attribute if present, null otherwise.
187 Attribute DictionaryAttr::get(StringRef name) const {
188   auto it = impl::findAttrSorted(begin(), end(), name);
189   return it.second ? it.first->second : Attribute();
190 }
191 Attribute DictionaryAttr::get(StringAttr name) const {
192   auto it = impl::findAttrSorted(begin(), end(), name);
193   return it.second ? it.first->second : Attribute();
194 }
195 
196 /// Return the specified named attribute if present, None otherwise.
197 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
198   auto it = impl::findAttrSorted(begin(), end(), name);
199   return it.second ? *it.first : Optional<NamedAttribute>();
200 }
201 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
202   auto it = impl::findAttrSorted(begin(), end(), name);
203   return it.second ? *it.first : Optional<NamedAttribute>();
204 }
205 
206 /// Return whether the specified attribute is present.
207 bool DictionaryAttr::contains(StringRef name) const {
208   return impl::findAttrSorted(begin(), end(), name).second;
209 }
210 bool DictionaryAttr::contains(StringAttr name) const {
211   return impl::findAttrSorted(begin(), end(), name).second;
212 }
213 
214 DictionaryAttr::iterator DictionaryAttr::begin() const {
215   return getValue().begin();
216 }
217 DictionaryAttr::iterator DictionaryAttr::end() const {
218   return getValue().end();
219 }
220 size_t DictionaryAttr::size() const { return getValue().size(); }
221 
222 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
223   return Base::get(context, ArrayRef<NamedAttribute>());
224 }
225 
226 void DictionaryAttr::walkImmediateSubElements(
227     function_ref<void(Attribute)> walkAttrsFn,
228     function_ref<void(Type)> walkTypesFn) const {
229   for (Attribute attr : llvm::make_second_range(getValue()))
230     walkAttrsFn(attr);
231 }
232 
233 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
234     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
235   std::vector<NamedAttribute> vec = getValue().vec();
236   for (auto &it : replacements) {
237     vec[it.first].second = it.second;
238   }
239   // The above only modifies the mapped value, but not the key, and therefore
240   // not the order of the elements. It remains sorted
241   return getWithSorted(getContext(), vec);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // StringAttr
246 //===----------------------------------------------------------------------===//
247 
248 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
249   return Base::get(context, "", NoneType::get(context));
250 }
251 
252 /// Twine support for StringAttr.
253 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
254   // Fast-path empty twine.
255   if (twine.isTriviallyEmpty())
256     return get(context);
257   SmallVector<char, 32> tempStr;
258   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
259 }
260 
261 /// Twine support for StringAttr.
262 StringAttr StringAttr::get(const Twine &twine, Type type) {
263   SmallVector<char, 32> tempStr;
264   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
265 }
266 
267 StringRef StringAttr::getValue() const { return getImpl()->value; }
268 
269 Dialect *StringAttr::getReferencedDialect() const {
270   return getImpl()->referencedDialect;
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // FloatAttr
275 //===----------------------------------------------------------------------===//
276 
277 double FloatAttr::getValueAsDouble() const {
278   return getValueAsDouble(getValue());
279 }
280 double FloatAttr::getValueAsDouble(APFloat value) {
281   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
282     bool losesInfo = false;
283     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
284                   &losesInfo);
285   }
286   return value.convertToDouble();
287 }
288 
289 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
290                                 Type type, APFloat value) {
291   // Verify that the type is correct.
292   if (!type.isa<FloatType>())
293     return emitError() << "expected floating point type";
294 
295   // Verify that the type semantics match that of the value.
296   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
297     return emitError()
298            << "FloatAttr type doesn't match the type implied by its value";
299   }
300   return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // SymbolRefAttr
305 //===----------------------------------------------------------------------===//
306 
307 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
308                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
309   return get(StringAttr::get(ctx, value), nestedRefs);
310 }
311 
312 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
313   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
314 }
315 
316 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
317   return get(value, {}).cast<FlatSymbolRefAttr>();
318 }
319 
320 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
321   auto symName =
322       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
323   assert(symName && "value does not have a valid symbol name");
324   return SymbolRefAttr::get(symName);
325 }
326 
327 StringAttr SymbolRefAttr::getLeafReference() const {
328   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
329   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // IntegerAttr
334 //===----------------------------------------------------------------------===//
335 
336 int64_t IntegerAttr::getInt() const {
337   assert((getType().isIndex() || getType().isSignlessInteger()) &&
338          "must be signless integer");
339   return getValue().getSExtValue();
340 }
341 
342 int64_t IntegerAttr::getSInt() const {
343   assert(getType().isSignedInteger() && "must be signed integer");
344   return getValue().getSExtValue();
345 }
346 
347 uint64_t IntegerAttr::getUInt() const {
348   assert(getType().isUnsignedInteger() && "must be unsigned integer");
349   return getValue().getZExtValue();
350 }
351 
352 /// Return the value as an APSInt which carries the signed from the type of
353 /// the attribute.  This traps on signless integers types!
354 APSInt IntegerAttr::getAPSInt() const {
355   assert(!getType().isSignlessInteger() &&
356          "Signless integers don't carry a sign for APSInt");
357   return APSInt(getValue(), getType().isUnsignedInteger());
358 }
359 
360 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
361                                   Type type, APInt value) {
362   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
363     if (integerType.getWidth() != value.getBitWidth())
364       return emitError() << "integer type bit width (" << integerType.getWidth()
365                          << ") doesn't match value bit width ("
366                          << value.getBitWidth() << ")";
367     return success();
368   }
369   if (type.isa<IndexType>())
370     return success();
371   return emitError() << "expected integer or index type";
372 }
373 
374 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
375   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
376   return attr.cast<BoolAttr>();
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // BoolAttr
381 
382 bool BoolAttr::getValue() const {
383   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
384   return storage->value.getBoolValue();
385 }
386 
387 bool BoolAttr::classof(Attribute attr) {
388   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
389   return intAttr && intAttr.getType().isSignlessInteger(1);
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // OpaqueAttr
394 //===----------------------------------------------------------------------===//
395 
396 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
397                                  StringAttr dialect, StringRef attrData,
398                                  Type type) {
399   if (!Dialect::isValidNamespace(dialect.strref()))
400     return emitError() << "invalid dialect namespace '" << dialect << "'";
401 
402   // Check that the dialect is actually registered.
403   MLIRContext *context = dialect.getContext();
404   if (!context->allowsUnregisteredDialects() &&
405       !context->getLoadedDialect(dialect.strref())) {
406     return emitError()
407            << "#" << dialect << "<\"" << attrData << "\"> : " << type
408            << " attribute created with unregistered dialect. If this is "
409               "intended, please call allowUnregisteredDialects() on the "
410               "MLIRContext, or use -allow-unregistered-dialect with "
411               "the MLIR opt tool used";
412   }
413 
414   return success();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // DenseElementsAttr Utilities
419 //===----------------------------------------------------------------------===//
420 
421 /// Get the bitwidth of a dense element type within the buffer.
422 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
423 static size_t getDenseElementStorageWidth(size_t origWidth) {
424   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
425 }
426 static size_t getDenseElementStorageWidth(Type elementType) {
427   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
428 }
429 
/// Set bit `bitPos` of the bit-addressed buffer `rawData` to `value`.
/// Bit `bitPos` lives in byte `bitPos / CHAR_BIT`, at position
/// `bitPos % CHAR_BIT` counted from the least significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}

/// Return bit `bitPos` of `rawData`, using the same layout as `setBit`.
static bool getBit(const char *rawData, size_t bitPos) {
  int mask = 1 << (bitPos % CHAR_BIT);
  char byte = rawData[bitPos / CHAR_BIT];
  return (byte & mask) != 0;
}
442 
/// Copy the first `numBytes` bytes of data held by `value` (APInt) into the
/// char array `result`, on a big-endian host.
///
/// All words of `value` except the last are copied verbatim. The last word is
/// byte-reversed into a temporary, and then only the bytes still needed to
/// reach `numBytes` are reversed back into `result`. This places the used
/// bytes of the last word in the right position on a BE host.
///
/// Preconditions (asserted): the host is big-endian, and `value` holds at least
/// `numBytes` bytes of words.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a byte count
  // (number of full words * bytes per word).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word in both `value` and
  // `result`.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  // Only the remaining `numBytes - lastWordPos` bytes are written, so `result`
  // is never written past `numBytes`.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
472 
/// Copy `numBytes` bytes from the char array `inArray` into the words of
/// `result` (APInt), on a big-endian host. This is the inverse of
/// copyAPIntToArrayForBEmachine.
///
/// Bytes destined for every word but the last are copied verbatim. The bytes
/// of the last (partial) word are byte-reversed into a word-sized temporary,
/// and that whole temporary is reversed back into the last word of `result`.
/// This places the data in the word's low-order bytes as the APInt expects on
/// a BE host.
///
/// Preconditions (asserted): the host is big-endian, and `result` has room for
/// at least `numBytes` bytes.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: despite its name, `numFilledWords` is a byte count
  // (number of full words * bytes per word).
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word in both `inArray` and
  // `result`; only `numBytes - lastWordPos` bytes are read from `inArray`.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
508 
509 /// Writes value to the bit position `bitPos` in array `rawData`.
510 static void writeBits(char *rawData, size_t bitPos, APInt value) {
511   size_t bitWidth = value.getBitWidth();
512 
513   // If the bitwidth is 1 we just toggle the specific bit.
514   if (bitWidth == 1)
515     return setBit(rawData, bitPos, value.isOneValue());
516 
517   // Otherwise, the bit position is guaranteed to be byte aligned.
518   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
519   if (llvm::support::endian::system_endianness() ==
520       llvm::support::endianness::big) {
521     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
522     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
523     // work correctly in BE format.
524     // ex. `value` (2 words including 10 bytes)
525     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
526     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
527                                  rawData + (bitPos / CHAR_BIT));
528   } else {
529     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
530                 llvm::divideCeil(bitWidth, CHAR_BIT),
531                 rawData + (bitPos / CHAR_BIT));
532   }
533 }
534 
535 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
536 /// `rawData`.
537 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
538   // Handle a boolean bit position.
539   if (bitWidth == 1)
540     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
541 
542   // Otherwise, the bit position must be 8-bit aligned.
543   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
544   APInt result(bitWidth, 0);
545   if (llvm::support::endian::system_endianness() ==
546       llvm::support::endianness::big) {
547     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
548     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
549     // work correctly in BE format.
550     // ex. `result` (2 words including 10 bytes)
551     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
552     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
553                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
554   } else {
555     std::copy_n(rawData + (bitPos / CHAR_BIT),
556                 llvm::divideCeil(bitWidth, CHAR_BIT),
557                 const_cast<char *>(
558                     reinterpret_cast<const char *>(result.getRawData())));
559   }
560   return result;
561 }
562 
563 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
564 /// the same element count as 'type'.
565 template <typename Values>
566 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
567   return (values.size() == 1) ||
568          (type.getNumElements() == static_cast<int64_t>(values.size()));
569 }
570 
571 //===----------------------------------------------------------------------===//
572 // DenseElementsAttr Iterators
573 //===----------------------------------------------------------------------===//
574 
575 //===----------------------------------------------------------------------===//
576 // AttributeElementIterator
577 
578 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
579     DenseElementsAttr attr, size_t index)
580     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
581                                       Attribute, Attribute, Attribute>(
582           attr.getAsOpaquePointer(), index) {}
583 
584 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
585   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
586   Type eltTy = owner.getElementType();
587   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
588     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
589   if (eltTy.isa<IndexType>())
590     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
591   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
592     IntElementIterator intIt(owner, index);
593     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
594     return FloatAttr::get(eltTy, *floatIt);
595   }
596   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
597     auto complexEltTy = complexTy.getElementType();
598     ComplexIntElementIterator complexIntIt(owner, index);
599     if (complexEltTy.isa<IntegerType>()) {
600       auto value = *complexIntIt;
601       auto real = IntegerAttr::get(complexEltTy, value.real());
602       auto imag = IntegerAttr::get(complexEltTy, value.imag());
603       return ArrayAttr::get(complexTy.getContext(),
604                             ArrayRef<Attribute>{real, imag});
605     }
606 
607     ComplexFloatElementIterator complexFloatIt(
608         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
609     auto value = *complexFloatIt;
610     auto real = FloatAttr::get(complexEltTy, value.real());
611     auto imag = FloatAttr::get(complexEltTy, value.imag());
612     return ArrayAttr::get(complexTy.getContext(),
613                           ArrayRef<Attribute>{real, imag});
614   }
615   if (owner.isa<DenseStringElementsAttr>()) {
616     ArrayRef<StringRef> vals = owner.getRawStringData();
617     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
618   }
619   llvm_unreachable("unexpected element type");
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // BoolElementIterator
624 
625 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
626     DenseElementsAttr attr, size_t dataIndex)
627     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
628           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
629 
630 bool DenseElementsAttr::BoolElementIterator::operator*() const {
631   return getBit(getData(), getDataIndex());
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // IntElementIterator
636 
637 DenseElementsAttr::IntElementIterator::IntElementIterator(
638     DenseElementsAttr attr, size_t dataIndex)
639     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
640           attr.getRawData().data(), attr.isSplat(), dataIndex),
641       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
642 
643 APInt DenseElementsAttr::IntElementIterator::operator*() const {
644   return readBits(getData(),
645                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
646                   bitWidth);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // ComplexIntElementIterator
651 
652 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
653     DenseElementsAttr attr, size_t dataIndex)
654     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
655                                       std::complex<APInt>, std::complex<APInt>,
656                                       std::complex<APInt>>(
657           attr.getRawData().data(), attr.isSplat(), dataIndex) {
658   auto complexType = attr.getElementType().cast<ComplexType>();
659   bitWidth = getDenseElementBitWidth(complexType.getElementType());
660 }
661 
662 std::complex<APInt>
663 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
664   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
665   size_t offset = getDataIndex() * storageWidth * 2;
666   return {readBits(getData(), offset, bitWidth),
667           readBits(getData(), offset + storageWidth, bitWidth)};
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // DenseElementsAttr
672 //===----------------------------------------------------------------------===//
673 
674 /// Method for support type inquiry through isa, cast and dyn_cast.
675 bool DenseElementsAttr::classof(Attribute attr) {
676   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
677 }
678 
679 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
680                                          ArrayRef<Attribute> values) {
681   assert(hasSameElementsOrSplat(type, values));
682 
683   // If the element type is not based on int/float/index, assume it is a string
684   // type.
685   auto eltType = type.getElementType();
686   if (!type.getElementType().isIntOrIndexOrFloat()) {
687     SmallVector<StringRef, 8> stringValues;
688     stringValues.reserve(values.size());
689     for (Attribute attr : values) {
690       assert(attr.isa<StringAttr>() &&
691              "expected string value for non integer/index/float element");
692       stringValues.push_back(attr.cast<StringAttr>().getValue());
693     }
694     return get(type, stringValues);
695   }
696 
697   // Otherwise, get the raw storage width to use for the allocation.
698   size_t bitWidth = getDenseElementBitWidth(eltType);
699   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
700 
701   // Compress the attribute values into a character buffer.
702   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
703                             values.size());
704   APInt intVal;
705   for (unsigned i = 0, e = values.size(); i < e; ++i) {
706     assert(eltType == values[i].getType() &&
707            "expected attribute value to have element type");
708     if (eltType.isa<FloatType>())
709       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
710     else if (eltType.isa<IntegerType, IndexType>())
711       intVal = values[i].cast<IntegerAttr>().getValue();
712     else
713       llvm_unreachable("unexpected element type");
714 
715     assert(intVal.getBitWidth() == bitWidth &&
716            "expected value to have same bitwidth as element type");
717     writeBits(data.data(), i * storageBitWidth, intVal);
718   }
719   return DenseIntOrFPElementsAttr::getRaw(type, data,
720                                           /*isSplat=*/(values.size() == 1));
721 }
722 
723 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
724                                          ArrayRef<bool> values) {
725   assert(hasSameElementsOrSplat(type, values));
726   assert(type.getElementType().isInteger(1));
727 
728   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
729   for (int i = 0, e = values.size(); i != e; ++i)
730     setBit(buff.data(), i, values[i]);
731   return DenseIntOrFPElementsAttr::getRaw(type, buff,
732                                           /*isSplat=*/(values.size() == 1));
733 }
734 
735 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
736                                          ArrayRef<StringRef> values) {
737   assert(!type.getElementType().isIntOrFloat());
738   return DenseStringElementsAttr::get(type, values);
739 }
740 
741 /// Constructs a dense integer elements attribute from an array of APInt
742 /// values. Each APInt value is expected to have the same bitwidth as the
743 /// element type of 'type'.
744 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
745                                          ArrayRef<APInt> values) {
746   assert(type.getElementType().isIntOrIndex());
747   assert(hasSameElementsOrSplat(type, values));
748   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
749   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
750                                           /*isSplat=*/(values.size() == 1));
751 }
752 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
753                                          ArrayRef<std::complex<APInt>> values) {
754   ComplexType complex = type.getElementType().cast<ComplexType>();
755   assert(complex.getElementType().isa<IntegerType>());
756   assert(hasSameElementsOrSplat(type, values));
757   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
758   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
759                           values.size() * 2);
760   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
761                                           /*isSplat=*/(values.size() == 1));
762 }
763 
764 // Constructs a dense float elements attribute from an array of APFloat
765 // values. Each APFloat value is expected to have the same bitwidth as the
766 // element type of 'type'.
767 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
768                                          ArrayRef<APFloat> values) {
769   assert(type.getElementType().isa<FloatType>());
770   assert(hasSameElementsOrSplat(type, values));
771   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
772   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
773                                           /*isSplat=*/(values.size() == 1));
774 }
775 DenseElementsAttr
776 DenseElementsAttr::get(ShapedType type,
777                        ArrayRef<std::complex<APFloat>> values) {
778   ComplexType complex = type.getElementType().cast<ComplexType>();
779   assert(complex.getElementType().isa<FloatType>());
780   assert(hasSameElementsOrSplat(type, values));
781   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
782                            values.size() * 2);
783   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
784   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
785                                           /*isSplat=*/(values.size() == 1));
786 }
787 
788 /// Construct a dense elements attribute from a raw buffer representing the
789 /// data for this attribute. Users should generally not use this methods as
790 /// the expected buffer format may not be a form the user expects.
791 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
792                                                       ArrayRef<char> rawBuffer,
793                                                       bool isSplatBuffer) {
794   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
795 }
796 
797 /// Returns true if the given buffer is a valid raw buffer for the given type.
798 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
799                                          ArrayRef<char> rawBuffer,
800                                          bool &detectedSplat) {
801   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
802   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
803 
804   // Storage width of 1 is special as it is packed by the bit.
805   if (storageWidth == 1) {
806     // Check for a splat, or a buffer equal to the number of elements which
807     // consists of either all 0's or all 1's.
808     detectedSplat = false;
809     if (rawBuffer.size() == 1) {
810       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
811       if (rawByte == 0 || rawByte == 0xff) {
812         detectedSplat = true;
813         return true;
814       }
815     }
816     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
817   }
818   // All other types are 8-bit aligned.
819   if ((detectedSplat = rawBufferWidth == storageWidth))
820     return true;
821   return rawBufferWidth == (storageWidth * type.getNumElements());
822 }
823 
824 /// Check the information for a C++ data type, check if this type is valid for
825 /// the current attribute. This method is used to verify specific type
826 /// invariants that the templatized 'getValues' method cannot.
827 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
828                               bool isSigned) {
829   // Make sure that the data element size is the same as the type element width.
830   if (getDenseElementBitWidth(type) !=
831       static_cast<size_t>(dataEltSize * CHAR_BIT))
832     return false;
833 
834   // Check that the element type is either float or integer or index.
835   if (!isInt)
836     return type.isa<FloatType>();
837   if (type.isIndex())
838     return true;
839 
840   auto intType = type.dyn_cast<IntegerType>();
841   if (!intType)
842     return false;
843 
844   // Make sure signedness semantics is consistent.
845   if (intType.isSignless())
846     return true;
847   return intType.isSigned() ? isSigned : !isSigned;
848 }
849 
850 /// Defaults down the subclass implementation.
851 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
852                                                    ArrayRef<char> data,
853                                                    int64_t dataEltSize,
854                                                    bool isInt, bool isSigned) {
855   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
856                                                  isSigned);
857 }
858 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
859                                                       ArrayRef<char> data,
860                                                       int64_t dataEltSize,
861                                                       bool isInt,
862                                                       bool isSigned) {
863   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
864                                                     isInt, isSigned);
865 }
866 
867 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
868                                           bool isSigned) const {
869   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
870 }
871 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
872                                        bool isSigned) const {
873   return ::isValidIntOrFloat(
874       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
875       isInt, isSigned);
876 }
877 
878 /// Returns true if this attribute corresponds to a splat, i.e. if all element
879 /// values are the same.
880 bool DenseElementsAttr::isSplat() const {
881   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
882 }
883 
884 /// Return if the given complex type has an integer element type.
885 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
886   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
887 }
888 
889 auto DenseElementsAttr::getComplexIntValues() const
890     -> iterator_range_impl<ComplexIntElementIterator> {
891   assert(isComplexOfIntType(getElementType()) &&
892          "expected complex integral type");
893   return {getType(), ComplexIntElementIterator(*this, 0),
894           ComplexIntElementIterator(*this, getNumElements())};
895 }
896 auto DenseElementsAttr::complex_value_begin() const
897     -> ComplexIntElementIterator {
898   assert(isComplexOfIntType(getElementType()) &&
899          "expected complex integral type");
900   return ComplexIntElementIterator(*this, 0);
901 }
902 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
903   assert(isComplexOfIntType(getElementType()) &&
904          "expected complex integral type");
905   return ComplexIntElementIterator(*this, getNumElements());
906 }
907 
908 /// Return the held element values as a range of APFloat. The element type of
909 /// this attribute must be of float type.
910 auto DenseElementsAttr::getFloatValues() const
911     -> iterator_range_impl<FloatElementIterator> {
912   auto elementType = getElementType().cast<FloatType>();
913   const auto &elementSemantics = elementType.getFloatSemantics();
914   return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
915           FloatElementIterator(elementSemantics, raw_int_end())};
916 }
917 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
918   auto elementType = getElementType().cast<FloatType>();
919   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
920 }
921 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
922   auto elementType = getElementType().cast<FloatType>();
923   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
924 }
925 
926 auto DenseElementsAttr::getComplexFloatValues() const
927     -> iterator_range_impl<ComplexFloatElementIterator> {
928   Type eltTy = getElementType().cast<ComplexType>().getElementType();
929   assert(eltTy.isa<FloatType>() && "expected complex float type");
930   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
931   return {getType(),
932           {semantics, {*this, 0}},
933           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
934 }
935 auto DenseElementsAttr::complex_float_value_begin() const
936     -> ComplexFloatElementIterator {
937   Type eltTy = getElementType().cast<ComplexType>().getElementType();
938   assert(eltTy.isa<FloatType>() && "expected complex float type");
939   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
940 }
941 auto DenseElementsAttr::complex_float_value_end() const
942     -> ComplexFloatElementIterator {
943   Type eltTy = getElementType().cast<ComplexType>().getElementType();
944   assert(eltTy.isa<FloatType>() && "expected complex float type");
945   return {eltTy.cast<FloatType>().getFloatSemantics(),
946           {*this, static_cast<size_t>(getNumElements())}};
947 }
948 
949 /// Return the raw storage data held by this attribute.
950 ArrayRef<char> DenseElementsAttr::getRawData() const {
951   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
952 }
953 
954 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
955   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
956 }
957 
958 /// Return a new DenseElementsAttr that has the same data as the current
959 /// attribute, but has been reshaped to 'newType'. The new type must have the
960 /// same total number of elements as well as element type.
961 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
962   ShapedType curType = getType();
963   if (curType == newType)
964     return *this;
965 
966   assert(newType.getElementType() == curType.getElementType() &&
967          "expected the same element type");
968   assert(newType.getNumElements() == curType.getNumElements() &&
969          "expected the same number of elements");
970   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
971 }
972 
973 /// Return a new DenseElementsAttr that has the same data as the current
974 /// attribute, but has bitcast elements such that it is now 'newType'. The new
975 /// type must have the same shape and element types of the same bitwidth as the
976 /// current type.
977 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
978   ShapedType curType = getType();
979   Type curElType = curType.getElementType();
980   if (curElType == newElType)
981     return *this;
982 
983   assert(getDenseElementBitWidth(newElType) ==
984              getDenseElementBitWidth(curElType) &&
985          "expected element types with the same bitwidth");
986   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
987                                           getRawData(), isSplat());
988 }
989 
990 DenseElementsAttr
991 DenseElementsAttr::mapValues(Type newElementType,
992                              function_ref<APInt(const APInt &)> mapping) const {
993   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
994 }
995 
996 DenseElementsAttr DenseElementsAttr::mapValues(
997     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
998   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
999 }
1000 
1001 ShapedType DenseElementsAttr::getType() const {
1002   return Attribute::getType().cast<ShapedType>();
1003 }
1004 
1005 Type DenseElementsAttr::getElementType() const {
1006   return getType().getElementType();
1007 }
1008 
1009 int64_t DenseElementsAttr::getNumElements() const {
1010   return getType().getNumElements();
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // DenseIntOrFPElementsAttr
1015 //===----------------------------------------------------------------------===//
1016 
1017 /// Utility method to write a range of APInt values to a buffer.
1018 template <typename APRangeT>
1019 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1020                                 APRangeT &&values) {
1021   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1022   size_t offset = 0;
1023   for (auto it = values.begin(), e = values.end(); it != e;
1024        ++it, offset += storageWidth) {
1025     assert((*it).getBitWidth() <= storageWidth);
1026     writeBits(data.data(), offset, *it);
1027   }
1028 }
1029 
1030 /// Constructs a dense elements attribute from an array of raw APFloat values.
1031 /// Each APFloat value is expected to have the same bitwidth as the element
1032 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1033 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1034                                                    size_t storageWidth,
1035                                                    ArrayRef<APFloat> values,
1036                                                    bool isSplat) {
1037   std::vector<char> data;
1038   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1039   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1040   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1041 }
1042 
1043 /// Constructs a dense elements attribute from an array of raw APInt values.
1044 /// Each APInt value is expected to have the same bitwidth as the element type
1045 /// of 'type'.
1046 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1047                                                    size_t storageWidth,
1048                                                    ArrayRef<APInt> values,
1049                                                    bool isSplat) {
1050   std::vector<char> data;
1051   writeAPIntsToBuffer(storageWidth, data, values);
1052   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1053 }
1054 
1055 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1056                                                    ArrayRef<char> data,
1057                                                    bool isSplat) {
1058   assert((type.isa<RankedTensorType, VectorType>()) &&
1059          "type must be ranked tensor or vector");
1060   assert(type.hasStaticShape() && "type must have static shape");
1061   return Base::get(type.getContext(), type, data, isSplat);
1062 }
1063 
1064 /// Overload of the raw 'get' method that asserts that the given type is of
1065 /// complex type. This method is used to verify type invariants that the
1066 /// templatized 'get' method cannot.
1067 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1068                                                           ArrayRef<char> data,
1069                                                           int64_t dataEltSize,
1070                                                           bool isInt,
1071                                                           bool isSigned) {
1072   assert(::isValidIntOrFloat(
1073       type.getElementType().cast<ComplexType>().getElementType(),
1074       dataEltSize / 2, isInt, isSigned));
1075 
1076   int64_t numElements = data.size() / dataEltSize;
1077   assert(numElements == 1 || numElements == type.getNumElements());
1078   return getRaw(type, data, /*isSplat=*/numElements == 1);
1079 }
1080 
1081 /// Overload of the 'getRaw' method that asserts that the given type is of
1082 /// integer type. This method is used to verify type invariants that the
1083 /// templatized 'get' method cannot.
1084 DenseElementsAttr
1085 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1086                                            int64_t dataEltSize, bool isInt,
1087                                            bool isSigned) {
1088   assert(
1089       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1090 
1091   int64_t numElements = data.size() / dataEltSize;
1092   assert(numElements == 1 || numElements == type.getNumElements());
1093   return getRaw(type, data, /*isSplat=*/numElements == 1);
1094 }
1095 
1096 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1097     const char *inRawData, char *outRawData, size_t elementBitWidth,
1098     size_t numElements) {
1099   using llvm::support::ulittle16_t;
1100   using llvm::support::ulittle32_t;
1101   using llvm::support::ulittle64_t;
1102 
1103   assert(llvm::support::endian::system_endianness() == // NOLINT
1104          llvm::support::endianness::big);              // NOLINT
1105   // NOLINT to avoid warning message about replacing by static_assert()
1106 
1107   // Following std::copy_n always converts endianness on BE machine.
1108   switch (elementBitWidth) {
1109   case 16: {
1110     const ulittle16_t *inRawDataPos =
1111         reinterpret_cast<const ulittle16_t *>(inRawData);
1112     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1113     std::copy_n(inRawDataPos, numElements, outDataPos);
1114     break;
1115   }
1116   case 32: {
1117     const ulittle32_t *inRawDataPos =
1118         reinterpret_cast<const ulittle32_t *>(inRawData);
1119     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1120     std::copy_n(inRawDataPos, numElements, outDataPos);
1121     break;
1122   }
1123   case 64: {
1124     const ulittle64_t *inRawDataPos =
1125         reinterpret_cast<const ulittle64_t *>(inRawData);
1126     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1127     std::copy_n(inRawDataPos, numElements, outDataPos);
1128     break;
1129   }
1130   default: {
1131     size_t nBytes = elementBitWidth / CHAR_BIT;
1132     for (size_t i = 0; i < nBytes; i++)
1133       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1134     break;
1135   }
1136   }
1137 }
1138 
1139 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1140     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1141     ShapedType type) {
1142   size_t numElements = type.getNumElements();
1143   Type elementType = type.getElementType();
1144   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1145     elementType = complexTy.getElementType();
1146     numElements = numElements * 2;
1147   }
1148   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1149   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1150          inRawData.size() <= outRawData.size());
1151   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1152                                   elementBitWidth, numElements);
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // DenseFPElementsAttr
1157 //===----------------------------------------------------------------------===//
1158 
1159 template <typename Fn, typename Attr>
1160 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1161                                 Type newElementType,
1162                                 llvm::SmallVectorImpl<char> &data) {
1163   size_t bitWidth = getDenseElementBitWidth(newElementType);
1164   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1165 
1166   ShapedType newArrayType;
1167   if (inType.isa<RankedTensorType>())
1168     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1169   else if (inType.isa<UnrankedTensorType>())
1170     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1171   else if (inType.isa<VectorType>())
1172     newArrayType = VectorType::get(inType.getShape(), newElementType);
1173   else
1174     assert(newArrayType && "Unhandled tensor type");
1175 
1176   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1177   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1178 
1179   // Functor used to process a single element value of the attribute.
1180   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1181     auto newInt = mapping(value);
1182     assert(newInt.getBitWidth() == bitWidth);
1183     writeBits(data.data(), index * storageBitWidth, newInt);
1184   };
1185 
1186   // Check for the splat case.
1187   if (attr.isSplat()) {
1188     processElt(*attr.begin(), /*index=*/0);
1189     return newArrayType;
1190   }
1191 
1192   // Otherwise, process all of the element values.
1193   uint64_t elementIdx = 0;
1194   for (auto value : attr)
1195     processElt(value, elementIdx++);
1196   return newArrayType;
1197 }
1198 
1199 DenseElementsAttr DenseFPElementsAttr::mapValues(
1200     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1201   llvm::SmallVector<char, 8> elementData;
1202   auto newArrayType =
1203       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1204 
1205   return getRaw(newArrayType, elementData, isSplat());
1206 }
1207 
1208 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1209 bool DenseFPElementsAttr::classof(Attribute attr) {
1210   return attr.isa<DenseElementsAttr>() &&
1211          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // DenseIntElementsAttr
1216 //===----------------------------------------------------------------------===//
1217 
1218 DenseElementsAttr DenseIntElementsAttr::mapValues(
1219     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1220   llvm::SmallVector<char, 8> elementData;
1221   auto newArrayType =
1222       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1223 
1224   return getRaw(newArrayType, elementData, isSplat());
1225 }
1226 
1227 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1228 bool DenseIntElementsAttr::classof(Attribute attr) {
1229   return attr.isa<DenseElementsAttr>() &&
1230          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1231 }
1232 
1233 //===----------------------------------------------------------------------===//
1234 // OpaqueElementsAttr
1235 //===----------------------------------------------------------------------===//
1236 
1237 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1238   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1239   if (!dialect)
1240     return true;
1241   auto *interface =
1242       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1243   if (!interface)
1244     return true;
1245   return failed(interface->decode(*this, result));
1246 }
1247 
1248 LogicalResult
1249 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1250                            StringAttr dialect, StringRef value,
1251                            ShapedType type) {
1252   if (!Dialect::isValidNamespace(dialect.strref()))
1253     return emitError() << "invalid dialect namespace '" << dialect << "'";
1254   return success();
1255 }
1256 
1257 //===----------------------------------------------------------------------===//
1258 // SparseElementsAttr
1259 //===----------------------------------------------------------------------===//
1260 
1261 /// Get a zero APFloat for the given sparse attribute.
1262 APFloat SparseElementsAttr::getZeroAPFloat() const {
1263   auto eltType = getElementType().cast<FloatType>();
1264   return APFloat(eltType.getFloatSemantics());
1265 }
1266 
1267 /// Get a zero APInt for the given sparse attribute.
1268 APInt SparseElementsAttr::getZeroAPInt() const {
1269   auto eltType = getElementType().cast<IntegerType>();
1270   return APInt::getZero(eltType.getWidth());
1271 }
1272 
1273 /// Get a zero attribute for the given attribute type.
1274 Attribute SparseElementsAttr::getZeroAttr() const {
1275   auto eltType = getElementType();
1276 
1277   // Handle floating point elements.
1278   if (eltType.isa<FloatType>())
1279     return FloatAttr::get(eltType, 0);
1280 
1281   // Handle string type.
1282   if (getValues().isa<DenseStringElementsAttr>())
1283     return StringAttr::get("", eltType);
1284 
1285   // Otherwise, this is an integer.
1286   return IntegerAttr::get(eltType, 0);
1287 }
1288 
/// Flatten, and return, all of the sparse indices in this attribute in
/// row-major order.
///
/// Each group of `rank` coordinates in the indices attribute is converted to
/// one linear offset with getFlattenedIndex. When the indices attribute is a
/// splat, a single flattened index is returned, built from the splat value
/// repeated once per dimension of this attribute's type.
std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
  std::vector<ptrdiff_t> flatSparseIndices;

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    // Every coordinate holds the same value, so a single index is produced.
    SmallVector<uint64_t, 8> indices(getType().getRank(),
                                     *sparseIndexValues.begin());
    flatSparseIndices.push_back(getFlattenedIndex(indices));
    return flatSparseIndices;
  }

  // Otherwise, reinterpret each index as an ArrayRef when flattening.
  // The number of sparse indices is the leading dimension of the indices type.
  // NOTE(review): the ArrayRef is formed from the address of an element of the
  // value range, which assumes that range is backed by contiguous uint64_t
  // storage.
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = getType().getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    flatSparseIndices.push_back(getFlattenedIndex(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
  return flatSparseIndices;
}
1313 
/// Verifies the invariants of a SparseElementsAttr of type `type`:
///   * `values` is a 1-D tensor;
///   * `sparseIndices` is 2-D with a trailing dimension equal to the rank of
///     `type`, or 1-D when `type` itself has rank 1;
///   * there is exactly one value per sparse index;
///   * every sparse index lies within the shape of `type`.
LogicalResult
SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           ShapedType type, DenseIntElementsAttr sparseIndices,
                           DenseElementsAttr values) {
  ShapedType valuesType = values.getType();
  if (valuesType.getRank() != 1)
    return emitError() << "expected 1-d tensor for sparse element values";

  // Verify the indices and values shape.
  ShapedType indicesType = sparseIndices.getType();
  // Shared diagnostic for every shape mismatch; reports all three shapes.
  auto emitShapeError = [&]() {
    return emitError() << "expected shape ([" << type.getShape()
                       << "]); inferred shape of indices literal (["
                       << indicesType.getShape()
                       << "]); inferred shape of values literal (["
                       << valuesType.getShape() << "])";
  };
  // Verify indices shape: [N x rank], or [N] when `type` has rank 1.
  size_t rank = type.getRank(), indicesRank = indicesType.getRank();
  if (indicesRank == 2) {
    if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
      return emitShapeError();
  } else if (indicesRank != 1 || rank != 1) {
    return emitShapeError();
  }
  // Verify the values shape: one value per sparse index.
  int64_t numSparseIndices = indicesType.getDimSize(0);
  if (numSparseIndices != valuesType.getDimSize(0))
    return emitShapeError();

  // Verify that the sparse indices are within the shape of `type`.
  auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
    return emitError()
           << "sparse index #" << indexNum
           << " is not contained within the value shape, with index=[" << index
           << "], and type=" << type;
  };

  // Handle the case where the index values are a splat.
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    // Every coordinate is the same, so a single check covers all indices.
    SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
    if (!ElementsAttr::isValidIndex(type, indices))
      return emitIndexError(0, indices);
    return success();
  }

  // Otherwise, reinterpret each index as an ArrayRef.
  // NOTE(review): this views `rank` consecutive values starting at an
  // element's address, which assumes contiguous uint64_t backing storage.
  for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
    ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
                             rank);
    if (!ElementsAttr::isValidIndex(type, index))
      return emitIndexError(i, index);
  }

  return success();
}
1371 
1372 //===----------------------------------------------------------------------===//
1373 // TypeAttr
1374 //===----------------------------------------------------------------------===//
1375 
1376 void TypeAttr::walkImmediateSubElements(
1377     function_ref<void(Attribute)> walkAttrsFn,
1378     function_ref<void(Type)> walkTypesFn) const {
1379   walkTypesFn(getValue());
1380 }
1381