1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
19 #include "llvm/ADT/APSInt.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
/// Registers each builtin attribute class listed below with the builtin
/// dialect.
// NOTE(review): BoolAttr, FlatSymbolRefAttr and DenseElementsAttr are not
// listed. Elsewhere in this file they are recognized via classof/cast over the
// registered IntegerAttr, SymbolRefAttr and DenseIntOrFP/DenseString classes,
// so presumably they have no storage of their own. Confirm before adding them.
void BuiltinDialect::registerAttributes() {
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
}
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // DictionaryAttr
58 //===----------------------------------------------------------------------===//
59 
60 /// Helper function that does either an in place sort or sorts from source array
61 /// into destination. If inPlace then storage is both the source and the
62 /// destination, else value is the source and storage destination. Returns
63 /// whether source was sorted.
64 template <bool inPlace>
65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
66                                SmallVectorImpl<NamedAttribute> &storage) {
67   // Specialize for the common case.
68   switch (value.size()) {
69   case 0:
70     // Zero already sorted.
71     break;
72   case 1:
73     // One already sorted but may need to be copied.
74     if (!inPlace)
75       storage.assign({value[0]});
76     break;
77   case 2: {
78     bool isSorted = value[0] < value[1];
79     if (inPlace) {
80       if (!isSorted)
81         std::swap(storage[0], storage[1]);
82     } else if (isSorted) {
83       storage.assign({value[0], value[1]});
84     } else {
85       storage.assign({value[1], value[0]});
86     }
87     return !isSorted;
88   }
89   default:
90     if (!inPlace)
91       storage.assign(value.begin(), value.end());
92     // Check to see they are sorted already.
93     bool isSorted = llvm::is_sorted(value);
94     // If not, do a general sort.
95     if (!isSorted)
96       llvm::array_pod_sort(storage.begin(), storage.end());
97     return !isSorted;
98   }
99   return false;
100 }
101 
102 /// Returns an entry with a duplicate name from the given sorted array of named
103 /// attributes. Returns llvm::None if all elements have unique names.
104 static Optional<NamedAttribute>
105 findDuplicateElement(ArrayRef<NamedAttribute> value) {
106   const Optional<NamedAttribute> none{llvm::None};
107   if (value.size() < 2)
108     return none;
109 
110   if (value.size() == 2)
111     return value[0].first == value[1].first ? value[0] : none;
112 
113   auto it = std::adjacent_find(
114       value.begin(), value.end(),
115       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
116   return it != value.end() ? *it : none;
117 }
118 
119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
120                           SmallVectorImpl<NamedAttribute> &storage) {
121   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
122   assert(!findDuplicateElement(storage) &&
123          "DictionaryAttr element names must be unique");
124   return isSorted;
125 }
126 
127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
128   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
129   assert(!findDuplicateElement(array) &&
130          "DictionaryAttr element names must be unique");
131   return isSorted;
132 }
133 
134 Optional<NamedAttribute>
135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
136                               bool isSorted) {
137   if (!isSorted)
138     dictionaryAttrSort</*inPlace=*/true>(array, array);
139   return findDuplicateElement(array);
140 }
141 
142 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
143                                    ArrayRef<NamedAttribute> value) {
144   if (value.empty())
145     return DictionaryAttr::getEmpty(context);
146   assert(llvm::all_of(value,
147                       [](const NamedAttribute &attr) { return attr.second; }) &&
148          "value cannot have null entries");
149 
150   // We need to sort the element list to canonicalize it.
151   SmallVector<NamedAttribute, 8> storage;
152   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
153     value = storage;
154   assert(!findDuplicateElement(value) &&
155          "DictionaryAttr element names must be unique");
156   return Base::get(context, value);
157 }
158 /// Construct a dictionary with an array of values that is known to already be
159 /// sorted by name and uniqued.
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
161                                              ArrayRef<NamedAttribute> value) {
162   if (value.empty())
163     return DictionaryAttr::getEmpty(context);
164   // Ensure that the attribute elements are unique and sorted.
165   assert(llvm::is_sorted(value,
166                          [](NamedAttribute l, NamedAttribute r) {
167                            return l.first.strref() < r.first.strref();
168                          }) &&
169          "expected attribute values to be sorted");
170   assert(!findDuplicateElement(value) &&
171          "DictionaryAttr element names must be unique");
172   return Base::get(context, value);
173 }
174 
175 /// Return the specified attribute if present, null otherwise.
176 Attribute DictionaryAttr::get(StringRef name) const {
177   Optional<NamedAttribute> attr = getNamed(name);
178   return attr ? attr->second : nullptr;
179 }
180 Attribute DictionaryAttr::get(Identifier name) const {
181   Optional<NamedAttribute> attr = getNamed(name);
182   return attr ? attr->second : nullptr;
183 }
184 
185 /// Return the specified named attribute if present, None otherwise.
186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
187   ArrayRef<NamedAttribute> values = getValue();
188   const auto *it = llvm::lower_bound(values, name);
189   return it != values.end() && it->first == name ? *it
190                                                  : Optional<NamedAttribute>();
191 }
192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
193   for (auto elt : getValue())
194     if (elt.first == name)
195       return elt;
196   return llvm::None;
197 }
198 
199 DictionaryAttr::iterator DictionaryAttr::begin() const {
200   return getValue().begin();
201 }
202 DictionaryAttr::iterator DictionaryAttr::end() const {
203   return getValue().end();
204 }
205 size_t DictionaryAttr::size() const { return getValue().size(); }
206 
207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
208   return Base::get(context, ArrayRef<NamedAttribute>());
209 }
210 
211 void DictionaryAttr::walkImmediateSubElements(
212     function_ref<void(Attribute)> walkAttrsFn,
213     function_ref<void(Type)> walkTypesFn) const {
214   for (Attribute attr : llvm::make_second_range(getValue()))
215     walkAttrsFn(attr);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // StringAttr
220 //===----------------------------------------------------------------------===//
221 
222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
223   return Base::get(context, "", NoneType::get(context));
224 }
225 
226 /// Twine support for StringAttr.
227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
228   // Fast-path empty twine.
229   if (twine.isTriviallyEmpty())
230     return get(context);
231   SmallVector<char, 32> tempStr;
232   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
233 }
234 
235 /// Twine support for StringAttr.
236 StringAttr StringAttr::get(const Twine &twine, Type type) {
237   SmallVector<char, 32> tempStr;
238   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // FloatAttr
243 //===----------------------------------------------------------------------===//
244 
245 double FloatAttr::getValueAsDouble() const {
246   return getValueAsDouble(getValue());
247 }
248 double FloatAttr::getValueAsDouble(APFloat value) {
249   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
250     bool losesInfo = false;
251     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
252                   &losesInfo);
253   }
254   return value.convertToDouble();
255 }
256 
257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
258                                 Type type, APFloat value) {
259   // Verify that the type is correct.
260   if (!type.isa<FloatType>())
261     return emitError() << "expected floating point type";
262 
263   // Verify that the type semantics match that of the value.
264   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
265     return emitError()
266            << "FloatAttr type doesn't match the type implied by its value";
267   }
268   return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // SymbolRefAttr
273 //===----------------------------------------------------------------------===//
274 
275 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
276                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
277   return get(StringAttr::get(ctx, value), nestedRefs);
278 }
279 
280 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
281   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
282 }
283 
284 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
285   return get(value, {}).cast<FlatSymbolRefAttr>();
286 }
287 
288 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
289   auto symName =
290       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
291   assert(symName && "value does not have a valid symbol name");
292   return SymbolRefAttr::get(symName);
293 }
294 
295 StringAttr SymbolRefAttr::getLeafReference() const {
296   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
297   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // IntegerAttr
302 //===----------------------------------------------------------------------===//
303 
304 int64_t IntegerAttr::getInt() const {
305   assert((getType().isIndex() || getType().isSignlessInteger()) &&
306          "must be signless integer");
307   return getValue().getSExtValue();
308 }
309 
310 int64_t IntegerAttr::getSInt() const {
311   assert(getType().isSignedInteger() && "must be signed integer");
312   return getValue().getSExtValue();
313 }
314 
315 uint64_t IntegerAttr::getUInt() const {
316   assert(getType().isUnsignedInteger() && "must be unsigned integer");
317   return getValue().getZExtValue();
318 }
319 
320 /// Return the value as an APSInt which carries the signed from the type of
321 /// the attribute.  This traps on signless integers types!
322 APSInt IntegerAttr::getAPSInt() const {
323   assert(!getType().isSignlessInteger() &&
324          "Signless integers don't carry a sign for APSInt");
325   return APSInt(getValue(), getType().isUnsignedInteger());
326 }
327 
328 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
329                                   Type type, APInt value) {
330   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
331     if (integerType.getWidth() != value.getBitWidth())
332       return emitError() << "integer type bit width (" << integerType.getWidth()
333                          << ") doesn't match value bit width ("
334                          << value.getBitWidth() << ")";
335     return success();
336   }
337   if (type.isa<IndexType>())
338     return success();
339   return emitError() << "expected integer or index type";
340 }
341 
342 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
343   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
344   return attr.cast<BoolAttr>();
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // BoolAttr
349 
350 bool BoolAttr::getValue() const {
351   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
352   return storage->value.getBoolValue();
353 }
354 
355 bool BoolAttr::classof(Attribute attr) {
356   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
357   return intAttr && intAttr.getType().isSignlessInteger(1);
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // OpaqueAttr
362 //===----------------------------------------------------------------------===//
363 
364 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
365                                  Identifier dialect, StringRef attrData,
366                                  Type type) {
367   if (!Dialect::isValidNamespace(dialect.strref()))
368     return emitError() << "invalid dialect namespace '" << dialect << "'";
369 
370   // Check that the dialect is actually registered.
371   MLIRContext *context = dialect.getContext();
372   if (!context->allowsUnregisteredDialects() &&
373       !context->getLoadedDialect(dialect.strref())) {
374     return emitError()
375            << "#" << dialect << "<\"" << attrData << "\"> : " << type
376            << " attribute created with unregistered dialect. If this is "
377               "intended, please call allowUnregisteredDialects() on the "
378               "MLIRContext, or use -allow-unregistered-dialect with "
379               "the MLIR opt tool used";
380   }
381 
382   return success();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // DenseElementsAttr Utilities
387 //===----------------------------------------------------------------------===//
388 
389 /// Get the bitwidth of a dense element type within the buffer.
390 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
391 static size_t getDenseElementStorageWidth(size_t origWidth) {
392   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
393 }
394 static size_t getDenseElementStorageWidth(Type elementType) {
395   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
396 }
397 
/// Sets bit `bitPos` of the packed buffer `rawData` to `value`. The bit lives
/// in byte `bitPos / CHAR_BIT`, at offset `bitPos % CHAR_BIT` counted from the
/// least-significant bit of that byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}
405 
/// Returns bit `bitPos` of the packed buffer `rawData`, using the same
/// byte/bit addressing as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
410 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Must only be called on a big-endian host (asserted). `value` must own at
/// least `numBytes` bytes of word storage (asserted). `result` receives exactly
/// `numBytes` bytes:
///  - every APInt word except the last is copied verbatim;
///  - only the used trailing bytes of the last (partially filled) word are
///    extracted, via two calls to convertEndianOfCharForBEmachine.
/// NOTE(review): convertEndianOfCharForBEmachine is defined elsewhere. From
/// the examples below it appears to reverse byte order within elements of the
/// given bit width; confirm against its definition.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Despite its name, `numFilledWords` is a byte count (words * word size).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
440 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// This is the inverse of copyAPIntToArrayForBEmachine. It must only be called
/// on a big-endian host (asserted), and `result` must already own at least
/// `numBytes` bytes of word storage (asserted). Every word of `result` except
/// the last is filled verbatim. The remaining `numBytes - lastWordPos` bytes
/// are placed at the low-order end of the last word via two calls to
/// convertEndianOfCharForBEmachine.
/// NOTE(review): this writes through a const_cast of APInt::getRawData(), so
/// it relies on that storage being mutable for a non-const APInt. Confirm this
/// remains valid for the APInt implementation in use.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Despite its name, `numFilledWords` is a byte count (words * word size).
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
476 
477 /// Writes value to the bit position `bitPos` in array `rawData`.
478 static void writeBits(char *rawData, size_t bitPos, APInt value) {
479   size_t bitWidth = value.getBitWidth();
480 
481   // If the bitwidth is 1 we just toggle the specific bit.
482   if (bitWidth == 1)
483     return setBit(rawData, bitPos, value.isOneValue());
484 
485   // Otherwise, the bit position is guaranteed to be byte aligned.
486   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
487   if (llvm::support::endian::system_endianness() ==
488       llvm::support::endianness::big) {
489     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
490     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
491     // work correctly in BE format.
492     // ex. `value` (2 words including 10 bytes)
493     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
494     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
495                                  rawData + (bitPos / CHAR_BIT));
496   } else {
497     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
498                 llvm::divideCeil(bitWidth, CHAR_BIT),
499                 rawData + (bitPos / CHAR_BIT));
500   }
501 }
502 
503 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
504 /// `rawData`.
505 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
506   // Handle a boolean bit position.
507   if (bitWidth == 1)
508     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
509 
510   // Otherwise, the bit position must be 8-bit aligned.
511   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
512   APInt result(bitWidth, 0);
513   if (llvm::support::endian::system_endianness() ==
514       llvm::support::endianness::big) {
515     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
516     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
517     // work correctly in BE format.
518     // ex. `result` (2 words including 10 bytes)
519     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
520     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
521                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
522   } else {
523     std::copy_n(rawData + (bitPos / CHAR_BIT),
524                 llvm::divideCeil(bitWidth, CHAR_BIT),
525                 const_cast<char *>(
526                     reinterpret_cast<const char *>(result.getRawData())));
527   }
528   return result;
529 }
530 
531 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
532 /// the same element count as 'type'.
533 template <typename Values>
534 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
535   return (values.size() == 1) ||
536          (type.getNumElements() == static_cast<int64_t>(values.size()));
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // DenseElementsAttr Iterators
541 //===----------------------------------------------------------------------===//
542 
543 //===----------------------------------------------------------------------===//
544 // AttributeElementIterator
545 
/// Constructs an iterator over the elements of `attr`, each produced as an
/// Attribute, positioned at element `index`. The owning attribute is stored as
/// an opaque pointer and recovered in operator*.
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
    DenseElementsAttr attr, size_t index)
    : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                      Attribute, Attribute, Attribute>(
          attr.getAsOpaquePointer(), index) {}
551 
552 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
553   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
554   Type eltTy = owner.getElementType();
555   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
556     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
557   if (eltTy.isa<IndexType>())
558     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
559   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
560     IntElementIterator intIt(owner, index);
561     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
562     return FloatAttr::get(eltTy, *floatIt);
563   }
564   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
565     auto complexEltTy = complexTy.getElementType();
566     ComplexIntElementIterator complexIntIt(owner, index);
567     if (complexEltTy.isa<IntegerType>()) {
568       auto value = *complexIntIt;
569       auto real = IntegerAttr::get(complexEltTy, value.real());
570       auto imag = IntegerAttr::get(complexEltTy, value.imag());
571       return ArrayAttr::get(complexTy.getContext(),
572                             ArrayRef<Attribute>{real, imag});
573     }
574 
575     ComplexFloatElementIterator complexFloatIt(
576         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
577     auto value = *complexFloatIt;
578     auto real = FloatAttr::get(complexEltTy, value.real());
579     auto imag = FloatAttr::get(complexEltTy, value.imag());
580     return ArrayAttr::get(complexTy.getContext(),
581                           ArrayRef<Attribute>{real, imag});
582   }
583   if (owner.isa<DenseStringElementsAttr>()) {
584     ArrayRef<StringRef> vals = owner.getRawStringData();
585     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
586   }
587   llvm_unreachable("unexpected element type");
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // BoolElementIterator
592 
/// Constructs an iterator over the bit-packed boolean elements of `attr`,
/// starting at element `dataIndex`. The raw data pointer and the splat flag
/// are forwarded to the indexed-iterator base.
// NOTE(review): splat handling (presumably every index maps to element 0) is
// implemented in DenseElementIndexedIteratorImpl, which is not visible here.
// Confirm there.
DenseElementsAttr::BoolElementIterator::BoolElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
          attr.getRawData().data(), attr.isSplat(), dataIndex) {}
597 
598 bool DenseElementsAttr::BoolElementIterator::operator*() const {
599   return getBit(getData(), getDataIndex());
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // IntElementIterator
604 
/// Constructs an iterator over the integer-encoded elements of `attr`,
/// starting at element `dataIndex`. The element bit width is computed once
/// here and cached in `bitWidth` for use by operator*.
DenseElementsAttr::IntElementIterator::IntElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
          attr.getRawData().data(), attr.isSplat(), dataIndex),
      bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
610 
611 APInt DenseElementsAttr::IntElementIterator::operator*() const {
612   return readBits(getData(),
613                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
614                   bitWidth);
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // ComplexIntElementIterator
619 
620 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
621     DenseElementsAttr attr, size_t dataIndex)
622     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
623                                       std::complex<APInt>, std::complex<APInt>,
624                                       std::complex<APInt>>(
625           attr.getRawData().data(), attr.isSplat(), dataIndex) {
626   auto complexType = attr.getElementType().cast<ComplexType>();
627   bitWidth = getDenseElementBitWidth(complexType.getElementType());
628 }
629 
630 std::complex<APInt>
631 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
632   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
633   size_t offset = getDataIndex() * storageWidth * 2;
634   return {readBits(getData(), offset, bitWidth),
635           readBits(getData(), offset + storageWidth, bitWidth)};
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // FloatElementIterator
640 
/// Wraps an IntElementIterator so that each raw APInt it yields is
/// reinterpreted as an APFloat with semantics `smt`.
// NOTE(review): the mapping lambda captures `smt` by reference, so the
// fltSemantics object must outlive this iterator. That holds when callers pass
// long-lived semantics objects (e.g. from FloatType::getFloatSemantics(), as
// in AttributeElementIterator::operator*); confirm for other callers.
DenseElementsAttr::FloatElementIterator::FloatElementIterator(
    const llvm::fltSemantics &smt, IntElementIterator it)
    : llvm::mapped_iterator<IntElementIterator,
                            std::function<APFloat(const APInt &)>>(
          it, [&](const APInt &val) { return APFloat(smt, val); }) {}
646 
647 //===----------------------------------------------------------------------===//
648 // ComplexFloatElementIterator
649 
/// Wraps a ComplexIntElementIterator so that each raw (real, imag) APInt pair
/// it yields is reinterpreted as a pair of APFloats with semantics `smt`.
// NOTE(review): as in FloatElementIterator, `smt` is captured by reference by
// the mapping lambda and must outlive this iterator. Confirm callers pass
// long-lived semantics objects.
DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
    const llvm::fltSemantics &smt, ComplexIntElementIterator it)
    : llvm::mapped_iterator<
          ComplexIntElementIterator,
          std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
          it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
            return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
          }) {}
658 
659 //===----------------------------------------------------------------------===//
660 // DenseElementsAttr
661 //===----------------------------------------------------------------------===//
662 
663 /// Method for support type inquiry through isa, cast and dyn_cast.
664 bool DenseElementsAttr::classof(Attribute attr) {
665   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
666 }
667 
668 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
669                                          ArrayRef<Attribute> values) {
670   assert(hasSameElementsOrSplat(type, values));
671 
672   // If the element type is not based on int/float/index, assume it is a string
673   // type.
674   auto eltType = type.getElementType();
675   if (!type.getElementType().isIntOrIndexOrFloat()) {
676     SmallVector<StringRef, 8> stringValues;
677     stringValues.reserve(values.size());
678     for (Attribute attr : values) {
679       assert(attr.isa<StringAttr>() &&
680              "expected string value for non integer/index/float element");
681       stringValues.push_back(attr.cast<StringAttr>().getValue());
682     }
683     return get(type, stringValues);
684   }
685 
686   // Otherwise, get the raw storage width to use for the allocation.
687   size_t bitWidth = getDenseElementBitWidth(eltType);
688   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
689 
690   // Compress the attribute values into a character buffer.
691   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
692                             values.size());
693   APInt intVal;
694   for (unsigned i = 0, e = values.size(); i < e; ++i) {
695     assert(eltType == values[i].getType() &&
696            "expected attribute value to have element type");
697     if (eltType.isa<FloatType>())
698       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
699     else if (eltType.isa<IntegerType, IndexType>())
700       intVal = values[i].cast<IntegerAttr>().getValue();
701     else
702       llvm_unreachable("unexpected element type");
703 
704     assert(intVal.getBitWidth() == bitWidth &&
705            "expected value to have same bitwidth as element type");
706     writeBits(data.data(), i * storageBitWidth, intVal);
707   }
708   return DenseIntOrFPElementsAttr::getRaw(type, data,
709                                           /*isSplat=*/(values.size() == 1));
710 }
711 
712 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
713                                          ArrayRef<bool> values) {
714   assert(hasSameElementsOrSplat(type, values));
715   assert(type.getElementType().isInteger(1));
716 
717   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
718   for (int i = 0, e = values.size(); i != e; ++i)
719     setBit(buff.data(), i, values[i]);
720   return DenseIntOrFPElementsAttr::getRaw(type, buff,
721                                           /*isSplat=*/(values.size() == 1));
722 }
723 
724 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
725                                          ArrayRef<StringRef> values) {
726   assert(!type.getElementType().isIntOrFloat());
727   return DenseStringElementsAttr::get(type, values);
728 }
729 
730 /// Constructs a dense integer elements attribute from an array of APInt
731 /// values. Each APInt value is expected to have the same bitwidth as the
732 /// element type of 'type'.
733 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
734                                          ArrayRef<APInt> values) {
735   assert(type.getElementType().isIntOrIndex());
736   assert(hasSameElementsOrSplat(type, values));
737   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
738   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
739                                           /*isSplat=*/(values.size() == 1));
740 }
741 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
742                                          ArrayRef<std::complex<APInt>> values) {
743   ComplexType complex = type.getElementType().cast<ComplexType>();
744   assert(complex.getElementType().isa<IntegerType>());
745   assert(hasSameElementsOrSplat(type, values));
746   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
747   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
748                           values.size() * 2);
749   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
750                                           /*isSplat=*/(values.size() == 1));
751 }
752 
753 // Constructs a dense float elements attribute from an array of APFloat
754 // values. Each APFloat value is expected to have the same bitwidth as the
755 // element type of 'type'.
756 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
757                                          ArrayRef<APFloat> values) {
758   assert(type.getElementType().isa<FloatType>());
759   assert(hasSameElementsOrSplat(type, values));
760   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
761   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
762                                           /*isSplat=*/(values.size() == 1));
763 }
764 DenseElementsAttr
765 DenseElementsAttr::get(ShapedType type,
766                        ArrayRef<std::complex<APFloat>> values) {
767   ComplexType complex = type.getElementType().cast<ComplexType>();
768   assert(complex.getElementType().isa<FloatType>());
769   assert(hasSameElementsOrSplat(type, values));
770   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
771                            values.size() * 2);
772   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
773   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
774                                           /*isSplat=*/(values.size() == 1));
775 }
776 
777 /// Construct a dense elements attribute from a raw buffer representing the
778 /// data for this attribute. Users should generally not use this methods as
779 /// the expected buffer format may not be a form the user expects.
780 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
781                                                       ArrayRef<char> rawBuffer,
782                                                       bool isSplatBuffer) {
783   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
784 }
785 
786 /// Returns true if the given buffer is a valid raw buffer for the given type.
787 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
788                                          ArrayRef<char> rawBuffer,
789                                          bool &detectedSplat) {
790   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
791   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
792 
793   // Storage width of 1 is special as it is packed by the bit.
794   if (storageWidth == 1) {
795     // Check for a splat, or a buffer equal to the number of elements which
796     // consists of either all 0's or all 1's.
797     detectedSplat = false;
798     if (rawBuffer.size() == 1) {
799       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
800       if (rawByte == 0 || rawByte == 0xff) {
801         detectedSplat = true;
802         return true;
803       }
804     }
805     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
806   }
807   // All other types are 8-bit aligned.
808   if ((detectedSplat = rawBufferWidth == storageWidth))
809     return true;
810   return rawBufferWidth == (storageWidth * type.getNumElements());
811 }
812 
813 /// Check the information for a C++ data type, check if this type is valid for
814 /// the current attribute. This method is used to verify specific type
815 /// invariants that the templatized 'getValues' method cannot.
816 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
817                               bool isSigned) {
818   // Make sure that the data element size is the same as the type element width.
819   if (getDenseElementBitWidth(type) !=
820       static_cast<size_t>(dataEltSize * CHAR_BIT))
821     return false;
822 
823   // Check that the element type is either float or integer or index.
824   if (!isInt)
825     return type.isa<FloatType>();
826   if (type.isIndex())
827     return true;
828 
829   auto intType = type.dyn_cast<IntegerType>();
830   if (!intType)
831     return false;
832 
833   // Make sure signedness semantics is consistent.
834   if (intType.isSignless())
835     return true;
836   return intType.isSigned() ? isSigned : !isSigned;
837 }
838 
839 /// Defaults down the subclass implementation.
840 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
841                                                    ArrayRef<char> data,
842                                                    int64_t dataEltSize,
843                                                    bool isInt, bool isSigned) {
844   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
845                                                  isSigned);
846 }
847 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
848                                                       ArrayRef<char> data,
849                                                       int64_t dataEltSize,
850                                                       bool isInt,
851                                                       bool isSigned) {
852   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
853                                                     isInt, isSigned);
854 }
855 
856 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
857                                           bool isSigned) const {
858   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
859 }
860 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
861                                        bool isSigned) const {
862   return ::isValidIntOrFloat(
863       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
864       isInt, isSigned);
865 }
866 
867 /// Returns true if this attribute corresponds to a splat, i.e. if all element
868 /// values are the same.
869 bool DenseElementsAttr::isSplat() const {
870   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
871 }
872 
873 /// Return if the given complex type has an integer element type.
874 static bool isComplexOfIntType(Type type) {
875   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
876 }
877 
878 auto DenseElementsAttr::getComplexIntValues() const
879     -> llvm::iterator_range<ComplexIntElementIterator> {
880   assert(isComplexOfIntType(getElementType()) &&
881          "expected complex integral type");
882   return {ComplexIntElementIterator(*this, 0),
883           ComplexIntElementIterator(*this, getNumElements())};
884 }
885 auto DenseElementsAttr::complex_value_begin() const
886     -> ComplexIntElementIterator {
887   assert(isComplexOfIntType(getElementType()) &&
888          "expected complex integral type");
889   return ComplexIntElementIterator(*this, 0);
890 }
891 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
892   assert(isComplexOfIntType(getElementType()) &&
893          "expected complex integral type");
894   return ComplexIntElementIterator(*this, getNumElements());
895 }
896 
897 /// Return the held element values as a range of APFloat. The element type of
898 /// this attribute must be of float type.
899 auto DenseElementsAttr::getFloatValues() const
900     -> llvm::iterator_range<FloatElementIterator> {
901   auto elementType = getElementType().cast<FloatType>();
902   const auto &elementSemantics = elementType.getFloatSemantics();
903   return {FloatElementIterator(elementSemantics, raw_int_begin()),
904           FloatElementIterator(elementSemantics, raw_int_end())};
905 }
906 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
907   auto elementType = getElementType().cast<FloatType>();
908   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
909 }
910 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
911   auto elementType = getElementType().cast<FloatType>();
912   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
913 }
914 
915 auto DenseElementsAttr::getComplexFloatValues() const
916     -> llvm::iterator_range<ComplexFloatElementIterator> {
917   Type eltTy = getElementType().cast<ComplexType>().getElementType();
918   assert(eltTy.isa<FloatType>() && "expected complex float type");
919   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
920   return {{semantics, {*this, 0}},
921           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
922 }
923 auto DenseElementsAttr::complex_float_value_begin() const
924     -> ComplexFloatElementIterator {
925   Type eltTy = getElementType().cast<ComplexType>().getElementType();
926   assert(eltTy.isa<FloatType>() && "expected complex float type");
927   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
928 }
929 auto DenseElementsAttr::complex_float_value_end() const
930     -> ComplexFloatElementIterator {
931   Type eltTy = getElementType().cast<ComplexType>().getElementType();
932   assert(eltTy.isa<FloatType>() && "expected complex float type");
933   return {eltTy.cast<FloatType>().getFloatSemantics(),
934           {*this, static_cast<size_t>(getNumElements())}};
935 }
936 
937 /// Return the raw storage data held by this attribute.
938 ArrayRef<char> DenseElementsAttr::getRawData() const {
939   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
940 }
941 
942 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
943   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
944 }
945 
946 /// Return a new DenseElementsAttr that has the same data as the current
947 /// attribute, but has been reshaped to 'newType'. The new type must have the
948 /// same total number of elements as well as element type.
949 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
950   ShapedType curType = getType();
951   if (curType == newType)
952     return *this;
953 
954   assert(newType.getElementType() == curType.getElementType() &&
955          "expected the same element type");
956   assert(newType.getNumElements() == curType.getNumElements() &&
957          "expected the same number of elements");
958   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
959 }
960 
961 /// Return a new DenseElementsAttr that has the same data as the current
962 /// attribute, but has bitcast elements such that it is now 'newType'. The new
963 /// type must have the same shape and element types of the same bitwidth as the
964 /// current type.
965 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
966   ShapedType curType = getType();
967   Type curElType = curType.getElementType();
968   if (curElType == newElType)
969     return *this;
970 
971   assert(getDenseElementBitWidth(newElType) ==
972              getDenseElementBitWidth(curElType) &&
973          "expected element types with the same bitwidth");
974   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
975                                           getRawData(), isSplat());
976 }
977 
978 DenseElementsAttr
979 DenseElementsAttr::mapValues(Type newElementType,
980                              function_ref<APInt(const APInt &)> mapping) const {
981   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
982 }
983 
984 DenseElementsAttr DenseElementsAttr::mapValues(
985     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
986   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
987 }
988 
989 ShapedType DenseElementsAttr::getType() const {
990   return Attribute::getType().cast<ShapedType>();
991 }
992 
993 Type DenseElementsAttr::getElementType() const {
994   return getType().getElementType();
995 }
996 
997 int64_t DenseElementsAttr::getNumElements() const {
998   return getType().getNumElements();
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // DenseIntOrFPElementsAttr
1003 //===----------------------------------------------------------------------===//
1004 
1005 /// Utility method to write a range of APInt values to a buffer.
1006 template <typename APRangeT>
1007 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1008                                 APRangeT &&values) {
1009   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1010   size_t offset = 0;
1011   for (auto it = values.begin(), e = values.end(); it != e;
1012        ++it, offset += storageWidth) {
1013     assert((*it).getBitWidth() <= storageWidth);
1014     writeBits(data.data(), offset, *it);
1015   }
1016 }
1017 
1018 /// Constructs a dense elements attribute from an array of raw APFloat values.
1019 /// Each APFloat value is expected to have the same bitwidth as the element
1020 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1021 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1022                                                    size_t storageWidth,
1023                                                    ArrayRef<APFloat> values,
1024                                                    bool isSplat) {
1025   std::vector<char> data;
1026   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1027   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1028   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1029 }
1030 
1031 /// Constructs a dense elements attribute from an array of raw APInt values.
1032 /// Each APInt value is expected to have the same bitwidth as the element type
1033 /// of 'type'.
1034 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1035                                                    size_t storageWidth,
1036                                                    ArrayRef<APInt> values,
1037                                                    bool isSplat) {
1038   std::vector<char> data;
1039   writeAPIntsToBuffer(storageWidth, data, values);
1040   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1041 }
1042 
1043 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1044                                                    ArrayRef<char> data,
1045                                                    bool isSplat) {
1046   assert((type.isa<RankedTensorType, VectorType>()) &&
1047          "type must be ranked tensor or vector");
1048   assert(type.hasStaticShape() && "type must have static shape");
1049   return Base::get(type.getContext(), type, data, isSplat);
1050 }
1051 
1052 /// Overload of the raw 'get' method that asserts that the given type is of
1053 /// complex type. This method is used to verify type invariants that the
1054 /// templatized 'get' method cannot.
1055 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1056                                                           ArrayRef<char> data,
1057                                                           int64_t dataEltSize,
1058                                                           bool isInt,
1059                                                           bool isSigned) {
1060   assert(::isValidIntOrFloat(
1061       type.getElementType().cast<ComplexType>().getElementType(),
1062       dataEltSize / 2, isInt, isSigned));
1063 
1064   int64_t numElements = data.size() / dataEltSize;
1065   assert(numElements == 1 || numElements == type.getNumElements());
1066   return getRaw(type, data, /*isSplat=*/numElements == 1);
1067 }
1068 
1069 /// Overload of the 'getRaw' method that asserts that the given type is of
1070 /// integer type. This method is used to verify type invariants that the
1071 /// templatized 'get' method cannot.
1072 DenseElementsAttr
1073 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1074                                            int64_t dataEltSize, bool isInt,
1075                                            bool isSigned) {
1076   assert(
1077       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1078 
1079   int64_t numElements = data.size() / dataEltSize;
1080   assert(numElements == 1 || numElements == type.getNumElements());
1081   return getRaw(type, data, /*isSplat=*/numElements == 1);
1082 }
1083 
1084 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1085     const char *inRawData, char *outRawData, size_t elementBitWidth,
1086     size_t numElements) {
1087   using llvm::support::ulittle16_t;
1088   using llvm::support::ulittle32_t;
1089   using llvm::support::ulittle64_t;
1090 
1091   assert(llvm::support::endian::system_endianness() == // NOLINT
1092          llvm::support::endianness::big);              // NOLINT
1093   // NOLINT to avoid warning message about replacing by static_assert()
1094 
1095   // Following std::copy_n always converts endianness on BE machine.
1096   switch (elementBitWidth) {
1097   case 16: {
1098     const ulittle16_t *inRawDataPos =
1099         reinterpret_cast<const ulittle16_t *>(inRawData);
1100     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1101     std::copy_n(inRawDataPos, numElements, outDataPos);
1102     break;
1103   }
1104   case 32: {
1105     const ulittle32_t *inRawDataPos =
1106         reinterpret_cast<const ulittle32_t *>(inRawData);
1107     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1108     std::copy_n(inRawDataPos, numElements, outDataPos);
1109     break;
1110   }
1111   case 64: {
1112     const ulittle64_t *inRawDataPos =
1113         reinterpret_cast<const ulittle64_t *>(inRawData);
1114     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1115     std::copy_n(inRawDataPos, numElements, outDataPos);
1116     break;
1117   }
1118   default: {
1119     size_t nBytes = elementBitWidth / CHAR_BIT;
1120     for (size_t i = 0; i < nBytes; i++)
1121       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1122     break;
1123   }
1124   }
1125 }
1126 
1127 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1128     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1129     ShapedType type) {
1130   size_t numElements = type.getNumElements();
1131   Type elementType = type.getElementType();
1132   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1133     elementType = complexTy.getElementType();
1134     numElements = numElements * 2;
1135   }
1136   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1137   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1138          inRawData.size() <= outRawData.size());
1139   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1140                                   elementBitWidth, numElements);
1141 }
1142 
1143 //===----------------------------------------------------------------------===//
1144 // DenseFPElementsAttr
1145 //===----------------------------------------------------------------------===//
1146 
1147 template <typename Fn, typename Attr>
1148 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1149                                 Type newElementType,
1150                                 llvm::SmallVectorImpl<char> &data) {
1151   size_t bitWidth = getDenseElementBitWidth(newElementType);
1152   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1153 
1154   ShapedType newArrayType;
1155   if (inType.isa<RankedTensorType>())
1156     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1157   else if (inType.isa<UnrankedTensorType>())
1158     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1159   else if (inType.isa<VectorType>())
1160     newArrayType = VectorType::get(inType.getShape(), newElementType);
1161   else
1162     assert(newArrayType && "Unhandled tensor type");
1163 
1164   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1165   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1166 
1167   // Functor used to process a single element value of the attribute.
1168   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1169     auto newInt = mapping(value);
1170     assert(newInt.getBitWidth() == bitWidth);
1171     writeBits(data.data(), index * storageBitWidth, newInt);
1172   };
1173 
1174   // Check for the splat case.
1175   if (attr.isSplat()) {
1176     processElt(*attr.begin(), /*index=*/0);
1177     return newArrayType;
1178   }
1179 
1180   // Otherwise, process all of the element values.
1181   uint64_t elementIdx = 0;
1182   for (auto value : attr)
1183     processElt(value, elementIdx++);
1184   return newArrayType;
1185 }
1186 
1187 DenseElementsAttr DenseFPElementsAttr::mapValues(
1188     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1189   llvm::SmallVector<char, 8> elementData;
1190   auto newArrayType =
1191       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1192 
1193   return getRaw(newArrayType, elementData, isSplat());
1194 }
1195 
1196 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1197 bool DenseFPElementsAttr::classof(Attribute attr) {
1198   return attr.isa<DenseElementsAttr>() &&
1199          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // DenseIntElementsAttr
1204 //===----------------------------------------------------------------------===//
1205 
1206 DenseElementsAttr DenseIntElementsAttr::mapValues(
1207     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1208   llvm::SmallVector<char, 8> elementData;
1209   auto newArrayType =
1210       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1211 
1212   return getRaw(newArrayType, elementData, isSplat());
1213 }
1214 
1215 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1216 bool DenseIntElementsAttr::classof(Attribute attr) {
1217   return attr.isa<DenseElementsAttr>() &&
1218          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // OpaqueElementsAttr
1223 //===----------------------------------------------------------------------===//
1224 
1225 /// Return the value at the given index. If index does not refer to a valid
1226 /// element, then a null attribute is returned.
1227 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1228   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1229   return Attribute();
1230 }
1231 
1232 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1233   Dialect *dialect = getDialect().getDialect();
1234   if (!dialect)
1235     return true;
1236   auto *interface =
1237       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1238   if (!interface)
1239     return true;
1240   return failed(interface->decode(*this, result));
1241 }
1242 
1243 LogicalResult
1244 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1245                            Identifier dialect, StringRef value,
1246                            ShapedType type) {
1247   if (!Dialect::isValidNamespace(dialect.strref()))
1248     return emitError() << "invalid dialect namespace '" << dialect << "'";
1249   return success();
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // SparseElementsAttr
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Return the value of the element at the given index.
1257 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1258   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1259   auto type = getType();
1260 
1261   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1262   // as a 1-D index array.
1263   auto sparseIndices = getIndices();
1264   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1265 
1266   // Check to see if the indices are a splat.
1267   if (sparseIndices.isSplat()) {
1268     // If the index is also not a splat of the index value, we know that the
1269     // value is zero.
1270     auto splatIndex = *sparseIndexValues.begin();
1271     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1272       return getZeroAttr();
1273 
1274     // If the indices are a splat, we also expect the values to be a splat.
1275     assert(getValues().isSplat() && "expected splat values");
1276     return getValues().getSplatValue();
1277   }
1278 
1279   // Build a mapping between known indices and the offset of the stored element.
1280   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1281   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1282   size_t rank = type.getRank();
1283   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1284     mappedIndices.try_emplace(
1285         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1286 
1287   // Look for the provided index key within the mapped indices. If the provided
1288   // index is not found, then return a zero attribute.
1289   auto it = mappedIndices.find(index);
1290   if (it == mappedIndices.end())
1291     return getZeroAttr();
1292 
1293   // Otherwise, return the held sparse value element.
1294   return getValues().getValue(it->second);
1295 }
1296 
1297 /// Get a zero APFloat for the given sparse attribute.
1298 APFloat SparseElementsAttr::getZeroAPFloat() const {
1299   auto eltType = getElementType().cast<FloatType>();
1300   return APFloat(eltType.getFloatSemantics());
1301 }
1302 
1303 /// Get a zero APInt for the given sparse attribute.
1304 APInt SparseElementsAttr::getZeroAPInt() const {
1305   auto eltType = getElementType().cast<IntegerType>();
1306   return APInt::getZero(eltType.getWidth());
1307 }
1308 
1309 /// Get a zero attribute for the given attribute type.
1310 Attribute SparseElementsAttr::getZeroAttr() const {
1311   auto eltType = getElementType();
1312 
1313   // Handle floating point elements.
1314   if (eltType.isa<FloatType>())
1315     return FloatAttr::get(eltType, 0);
1316 
1317   // Otherwise, this is an integer.
1318   // TODO: Handle StringAttr here.
1319   return IntegerAttr::get(eltType, 0);
1320 }
1321 
1322 /// Flatten, and return, all of the sparse indices in this attribute in
1323 /// row-major order.
1324 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1325   std::vector<ptrdiff_t> flatSparseIndices;
1326 
1327   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1328   // as a 1-D index array.
1329   auto sparseIndices = getIndices();
1330   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1331   if (sparseIndices.isSplat()) {
1332     SmallVector<uint64_t, 8> indices(getType().getRank(),
1333                                      *sparseIndexValues.begin());
1334     flatSparseIndices.push_back(getFlattenedIndex(indices));
1335     return flatSparseIndices;
1336   }
1337 
1338   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1339   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1340   size_t rank = getType().getRank();
1341   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1342     flatSparseIndices.push_back(getFlattenedIndex(
1343         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1344   return flatSparseIndices;
1345 }
1346 
/// Verifies the structural invariants of a sparse elements attribute of
/// shaped type `type`. The checks run in this order and stop at the first
/// failure:
///  - `values` is a 1-D tensor;
///  - `sparseIndices` is 2-D with dim 1 equal to rank(type), or is 1-D when
///    `type` itself has rank 1;
///  - `values` has as many entries as there are sparse indices (dim 0);
///  - every sparse index is a valid index into `type`.
/// Failures are reported through `emitError`.
LogicalResult
SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           ShapedType type, DenseIntElementsAttr sparseIndices,
                           DenseElementsAttr values) {
  ShapedType valuesType = values.getType();
  if (valuesType.getRank() != 1)
    return emitError() << "expected 1-d tensor for sparse element values";

  // Verify the indices and values shape.
  ShapedType indicesType = sparseIndices.getType();
  // Reports the expected shape alongside both inferred literal shapes.
  auto emitShapeError = [&]() {
    return emitError() << "expected shape ([" << type.getShape()
                       << "]); inferred shape of indices literal (["
                       << indicesType.getShape()
                       << "]); inferred shape of values literal (["
                       << valuesType.getShape() << "])";
  };
  // Verify indices shape.
  // 2-D indices hold one `rank`-wide coordinate per row; 1-D indices are only
  // accepted when the attribute type itself has rank 1.
  size_t rank = type.getRank(), indicesRank = indicesType.getRank();
  if (indicesRank == 2) {
    if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
      return emitShapeError();
  } else if (indicesRank != 1 || rank != 1) {
    return emitShapeError();
  }
  // Verify the values shape.
  // There must be exactly one value per sparse coordinate.
  int64_t numSparseIndices = indicesType.getDimSize(0);
  if (numSparseIndices != valuesType.getDimSize(0))
    return emitShapeError();

  // Verify that the sparse indices are within the value shape.
  auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
    return emitError()
           << "sparse index #" << indexNum
           << " is not contained within the value shape, with index=[" << index
           << "], and type=" << type;
  };

  // Handle the case where the index values are a splat.
  // A splat stands for a single coordinate whose every component is the splat
  // value, so only that one coordinate needs checking.
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
    if (!ElementsAttr::isValidIndex(type, indices))
      return emitIndexError(0, indices);
    return success();
  }

  // Otherwise, reinterpret each index as an ArrayRef.
  // Coordinate `i` is the `rank`-sized slice starting at offset `i * rank` of
  // the flat index array.
  for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
    ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
                             rank);
    if (!ElementsAttr::isValidIndex(type, index))
      return emitIndexError(i, index);
  }

  return success();
}
1404 
1405 //===----------------------------------------------------------------------===//
1406 // TypeAttr
1407 //===----------------------------------------------------------------------===//
1408 
1409 void TypeAttr::walkImmediateSubElements(
1410     function_ref<void(Attribute)> walkAttrsFn,
1411     function_ref<void(Type)> walkTypesFn) const {
1412   walkTypesFn(getValue());
1413 }
1414