1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/APSInt.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
/// Registers the builtin attribute classes listed below with the builtin
/// dialect.
/// NOTE(review): this list is written out by hand; presumably any new builtin
/// attribute class with its own storage must also be added here — confirm.
void BuiltinDialect::registerAttributes() {
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
}
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // DictionaryAttr
58 //===----------------------------------------------------------------------===//
59 
/// Helper that sorts an array of named attributes (using NamedAttribute's
/// `operator<`), either in place or by copying from a source array into a
/// destination.
///
/// If `inPlace` is true, `storage` is both the source and the destination;
/// callers in this file pass the same array as both `value` and `storage` in
/// that case. Otherwise `value` is the source and `storage` the destination:
/// for any non-empty input `storage` is overwritten with the (sorted)
/// elements, even when they were already in order.
///
/// Returns true if the source was NOT already sorted (i.e. the elements had to
/// be reordered), and false if it was already in order.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    // Two elements: a single comparison decides the order.
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    // When not in place, copy first so that `storage` always ends up holding
    // the elements; the sortedness check below reads `value`, which for the
    // in-place callers in this file aliases `storage`.
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already.
    bool isSorted = llvm::is_sorted(value);
    // If not, do a general sort.
    if (!isSorted)
      llvm::array_pod_sort(storage.begin(), storage.end());
    return !isSorted;
  }
  return false;
}
101 
102 /// Returns an entry with a duplicate name from the given sorted array of named
103 /// attributes. Returns llvm::None if all elements have unique names.
104 static Optional<NamedAttribute>
105 findDuplicateElement(ArrayRef<NamedAttribute> value) {
106   const Optional<NamedAttribute> none{llvm::None};
107   if (value.size() < 2)
108     return none;
109 
110   if (value.size() == 2)
111     return value[0].first == value[1].first ? value[0] : none;
112 
113   auto it = std::adjacent_find(
114       value.begin(), value.end(),
115       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
116   return it != value.end() ? *it : none;
117 }
118 
119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
120                           SmallVectorImpl<NamedAttribute> &storage) {
121   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
122   assert(!findDuplicateElement(storage) &&
123          "DictionaryAttr element names must be unique");
124   return isSorted;
125 }
126 
127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
128   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
129   assert(!findDuplicateElement(array) &&
130          "DictionaryAttr element names must be unique");
131   return isSorted;
132 }
133 
134 Optional<NamedAttribute>
135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
136                               bool isSorted) {
137   if (!isSorted)
138     dictionaryAttrSort</*inPlace=*/true>(array, array);
139   return findDuplicateElement(array);
140 }
141 
142 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
143                                    ArrayRef<NamedAttribute> value) {
144   if (value.empty())
145     return DictionaryAttr::getEmpty(context);
146   assert(llvm::all_of(value,
147                       [](const NamedAttribute &attr) { return attr.second; }) &&
148          "value cannot have null entries");
149 
150   // We need to sort the element list to canonicalize it.
151   SmallVector<NamedAttribute, 8> storage;
152   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
153     value = storage;
154   assert(!findDuplicateElement(value) &&
155          "DictionaryAttr element names must be unique");
156   return Base::get(context, value);
157 }
158 /// Construct a dictionary with an array of values that is known to already be
159 /// sorted by name and uniqued.
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
161                                              ArrayRef<NamedAttribute> value) {
162   if (value.empty())
163     return DictionaryAttr::getEmpty(context);
164   // Ensure that the attribute elements are unique and sorted.
165   assert(llvm::is_sorted(value,
166                          [](NamedAttribute l, NamedAttribute r) {
167                            return l.first.strref() < r.first.strref();
168                          }) &&
169          "expected attribute values to be sorted");
170   assert(!findDuplicateElement(value) &&
171          "DictionaryAttr element names must be unique");
172   return Base::get(context, value);
173 }
174 
175 /// Return the specified attribute if present, null otherwise.
176 Attribute DictionaryAttr::get(StringRef name) const {
177   Optional<NamedAttribute> attr = getNamed(name);
178   return attr ? attr->second : nullptr;
179 }
180 Attribute DictionaryAttr::get(Identifier name) const {
181   Optional<NamedAttribute> attr = getNamed(name);
182   return attr ? attr->second : nullptr;
183 }
184 
185 /// Return the specified named attribute if present, None otherwise.
186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
187   ArrayRef<NamedAttribute> values = getValue();
188   const auto *it = llvm::lower_bound(values, name);
189   return it != values.end() && it->first == name ? *it
190                                                  : Optional<NamedAttribute>();
191 }
192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
193   for (auto elt : getValue())
194     if (elt.first == name)
195       return elt;
196   return llvm::None;
197 }
198 
199 DictionaryAttr::iterator DictionaryAttr::begin() const {
200   return getValue().begin();
201 }
202 DictionaryAttr::iterator DictionaryAttr::end() const {
203   return getValue().end();
204 }
205 size_t DictionaryAttr::size() const { return getValue().size(); }
206 
207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
208   return Base::get(context, ArrayRef<NamedAttribute>());
209 }
210 
211 void DictionaryAttr::walkImmediateSubElements(
212     function_ref<void(Attribute)> walkAttrsFn,
213     function_ref<void(Type)> walkTypesFn) const {
214   for (Attribute attr : llvm::make_second_range(getValue()))
215     walkAttrsFn(attr);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // StringAttr
220 //===----------------------------------------------------------------------===//
221 
222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
223   return Base::get(context, "", NoneType::get(context));
224 }
225 
226 /// Twine support for StringAttr.
227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
228   // Fast-path empty twine.
229   if (twine.isTriviallyEmpty())
230     return get(context);
231   SmallVector<char, 32> tempStr;
232   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
233 }
234 
235 /// Twine support for StringAttr.
236 StringAttr StringAttr::get(const Twine &twine, Type type) {
237   SmallVector<char, 32> tempStr;
238   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // FloatAttr
243 //===----------------------------------------------------------------------===//
244 
245 double FloatAttr::getValueAsDouble() const {
246   return getValueAsDouble(getValue());
247 }
248 double FloatAttr::getValueAsDouble(APFloat value) {
249   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
250     bool losesInfo = false;
251     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
252                   &losesInfo);
253   }
254   return value.convertToDouble();
255 }
256 
257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
258                                 Type type, APFloat value) {
259   // Verify that the type is correct.
260   if (!type.isa<FloatType>())
261     return emitError() << "expected floating point type";
262 
263   // Verify that the type semantics match that of the value.
264   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
265     return emitError()
266            << "FloatAttr type doesn't match the type implied by its value";
267   }
268   return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // SymbolRefAttr
273 //===----------------------------------------------------------------------===//
274 
275 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
276   return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
277 }
278 
279 StringRef SymbolRefAttr::getLeafReference() const {
280   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
281   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // IntegerAttr
286 //===----------------------------------------------------------------------===//
287 
288 int64_t IntegerAttr::getInt() const {
289   assert((getType().isIndex() || getType().isSignlessInteger()) &&
290          "must be signless integer");
291   return getValue().getSExtValue();
292 }
293 
294 int64_t IntegerAttr::getSInt() const {
295   assert(getType().isSignedInteger() && "must be signed integer");
296   return getValue().getSExtValue();
297 }
298 
299 uint64_t IntegerAttr::getUInt() const {
300   assert(getType().isUnsignedInteger() && "must be unsigned integer");
301   return getValue().getZExtValue();
302 }
303 
304 /// Return the value as an APSInt which carries the signed from the type of
305 /// the attribute.  This traps on signless integers types!
306 APSInt IntegerAttr::getAPSInt() const {
307   assert(!getType().isSignlessInteger() &&
308          "Signless integers don't carry a sign for APSInt");
309   return APSInt(getValue(), getType().isUnsignedInteger());
310 }
311 
312 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
313                                   Type type, APInt value) {
314   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
315     if (integerType.getWidth() != value.getBitWidth())
316       return emitError() << "integer type bit width (" << integerType.getWidth()
317                          << ") doesn't match value bit width ("
318                          << value.getBitWidth() << ")";
319     return success();
320   }
321   if (type.isa<IndexType>())
322     return success();
323   return emitError() << "expected integer or index type";
324 }
325 
326 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
327   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
328   return attr.cast<BoolAttr>();
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // BoolAttr
333 
334 bool BoolAttr::getValue() const {
335   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
336   return storage->value.getBoolValue();
337 }
338 
339 bool BoolAttr::classof(Attribute attr) {
340   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
341   return intAttr && intAttr.getType().isSignlessInteger(1);
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // OpaqueAttr
346 //===----------------------------------------------------------------------===//
347 
348 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
349                                  Identifier dialect, StringRef attrData,
350                                  Type type) {
351   if (!Dialect::isValidNamespace(dialect.strref()))
352     return emitError() << "invalid dialect namespace '" << dialect << "'";
353 
354   // Check that the dialect is actually registered.
355   MLIRContext *context = dialect.getContext();
356   if (!context->allowsUnregisteredDialects() &&
357       !context->getLoadedDialect(dialect.strref())) {
358     return emitError()
359            << "#" << dialect << "<\"" << attrData << "\"> : " << type
360            << " attribute created with unregistered dialect. If this is "
361               "intended, please call allowUnregisteredDialects() on the "
362               "MLIRContext, or use -allow-unregistered-dialect with "
363               "the MLIR opt tool used";
364   }
365 
366   return success();
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // ElementsAttr
371 //===----------------------------------------------------------------------===//
372 
373 ShapedType ElementsAttr::getType() const {
374   return Attribute::getType().cast<ShapedType>();
375 }
376 
377 /// Returns the number of elements held by this attribute.
378 int64_t ElementsAttr::getNumElements() const {
379   return getType().getNumElements();
380 }
381 
382 /// Return the value at the given index. If index does not refer to a valid
383 /// element, then a null attribute is returned.
384 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
385   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
386     return denseAttr.getValue(index);
387   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
388     return opaqueAttr.getValue(index);
389   return cast<SparseElementsAttr>().getValue(index);
390 }
391 
392 /// Return if the given 'index' refers to a valid element in this attribute.
393 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
394   auto type = getType();
395 
396   // Verify that the rank of the indices matches the held type.
397   auto rank = type.getRank();
398   if (rank == 0 && index.size() == 1 && index[0] == 0)
399     return true;
400   if (rank != static_cast<int64_t>(index.size()))
401     return false;
402 
403   // Verify that all of the indices are within the shape dimensions.
404   auto shape = type.getShape();
405   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
406     int64_t dim = static_cast<int64_t>(index[i]);
407     return 0 <= dim && dim < shape[i];
408   });
409 }
410 
411 ElementsAttr
412 ElementsAttr::mapValues(Type newElementType,
413                         function_ref<APInt(const APInt &)> mapping) const {
414   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
415     return intOrFpAttr.mapValues(newElementType, mapping);
416   llvm_unreachable("unsupported ElementsAttr subtype");
417 }
418 
419 ElementsAttr
420 ElementsAttr::mapValues(Type newElementType,
421                         function_ref<APInt(const APFloat &)> mapping) const {
422   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
423     return intOrFpAttr.mapValues(newElementType, mapping);
424   llvm_unreachable("unsupported ElementsAttr subtype");
425 }
426 
427 /// Method for support type inquiry through isa, cast and dyn_cast.
428 bool ElementsAttr::classof(Attribute attr) {
429   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
430                   OpaqueElementsAttr, SparseElementsAttr>();
431 }
432 
433 /// Returns the 1 dimensional flattened row-major index from the given
434 /// multi-dimensional index.
435 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
436   assert(isValidIndex(index) && "expected valid multi-dimensional index");
437   auto type = getType();
438 
439   // Reduce the provided multidimensional index into a flattended 1D row-major
440   // index.
441   auto rank = type.getRank();
442   auto shape = type.getShape();
443   uint64_t valueIndex = 0;
444   uint64_t dimMultiplier = 1;
445   for (int i = rank - 1; i >= 0; --i) {
446     valueIndex += index[i] * dimMultiplier;
447     dimMultiplier *= shape[i];
448   }
449   return valueIndex;
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // DenseElementsAttr Utilities
454 //===----------------------------------------------------------------------===//
455 
456 /// Get the bitwidth of a dense element type within the buffer.
457 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
458 static size_t getDenseElementStorageWidth(size_t origWidth) {
459   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
460 }
461 static size_t getDenseElementStorageWidth(Type elementType) {
462   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
463 }
464 
/// Set bit number `bitPos` of the bit stream `rawData` to `value`. Bit 0 is the
/// least significant bit of rawData[0], bit CHAR_BIT the LSB of rawData[1],
/// and so on.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}

/// Read bit number `bitPos` of `rawData`, using the same numbering as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
477 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format. Only valid on big-endian hosts (asserted below); used by
/// writeBits to serialize an APInt into the dense element buffer.
///
/// Requires `value` to have at least `numBytes` bytes of word storage. The
/// arithmetic `numBytes - lastWordPos` below also assumes that `numBytes` is
/// larger than the byte size of all but the last word, which holds when
/// `numBytes` is divideCeil(bitWidth, CHAR_BIT) as in writeBits.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // NOTE: despite its name, `numFilledWords` is a BYTE count: the size in
  // bytes of every word except the last one.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word (in both `value`'s raw
  // storage and `result`).
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  // Only the remaining `numBytes - lastWordPos` bytes are written, so the
  // unused padding bytes of the last word never reach `result`.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
507 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format. Only valid on big-endian hosts (asserted below); this is the inverse
/// of copyAPIntToArrayForBEmachine and is used by readBits.
///
/// `result` must already be sized so that its word storage holds at least
/// `numBytes` bytes. Its raw words are written directly through a const_cast
/// of `getRawData()`.
/// NOTE(review): writing through getRawData() relies on APInt's internal
/// layout; confirm it remains valid if APInt's storage changes.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // NOTE: despite its name, `numFilledWords` is a BYTE count: the size in
  // bytes of every word except the last one.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  // A full word's worth of bits is converted here so the whole last word of
  // `result` is rewritten from the scratch buffer.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
543 
544 /// Writes value to the bit position `bitPos` in array `rawData`.
545 static void writeBits(char *rawData, size_t bitPos, APInt value) {
546   size_t bitWidth = value.getBitWidth();
547 
548   // If the bitwidth is 1 we just toggle the specific bit.
549   if (bitWidth == 1)
550     return setBit(rawData, bitPos, value.isOneValue());
551 
552   // Otherwise, the bit position is guaranteed to be byte aligned.
553   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
554   if (llvm::support::endian::system_endianness() ==
555       llvm::support::endianness::big) {
556     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
557     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
558     // work correctly in BE format.
559     // ex. `value` (2 words including 10 bytes)
560     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
561     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
562                                  rawData + (bitPos / CHAR_BIT));
563   } else {
564     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
565                 llvm::divideCeil(bitWidth, CHAR_BIT),
566                 rawData + (bitPos / CHAR_BIT));
567   }
568 }
569 
570 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
571 /// `rawData`.
572 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
573   // Handle a boolean bit position.
574   if (bitWidth == 1)
575     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
576 
577   // Otherwise, the bit position must be 8-bit aligned.
578   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
579   APInt result(bitWidth, 0);
580   if (llvm::support::endian::system_endianness() ==
581       llvm::support::endianness::big) {
582     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
583     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
584     // work correctly in BE format.
585     // ex. `result` (2 words including 10 bytes)
586     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
587     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
588                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
589   } else {
590     std::copy_n(rawData + (bitPos / CHAR_BIT),
591                 llvm::divideCeil(bitWidth, CHAR_BIT),
592                 const_cast<char *>(
593                     reinterpret_cast<const char *>(result.getRawData())));
594   }
595   return result;
596 }
597 
598 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
599 /// the same element count as 'type'.
600 template <typename Values>
601 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
602   return (values.size() == 1) ||
603          (type.getNumElements() == static_cast<int64_t>(values.size()));
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // DenseElementsAttr Iterators
608 //===----------------------------------------------------------------------===//
609 
610 //===----------------------------------------------------------------------===//
611 // AttributeElementIterator
612 
/// Construct an iterator over the elements of `attr` positioned at element
/// `index`. Only the attribute's opaque pointer and the index are stored; the
/// element Attribute is built when the iterator is dereferenced.
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
    DenseElementsAttr attr, size_t index)
    : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                      Attribute, Attribute, Attribute>(
          attr.getAsOpaquePointer(), index) {}
618 
619 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
620   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
621   Type eltTy = owner.getType().getElementType();
622   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
623     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
624   if (eltTy.isa<IndexType>())
625     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
626   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
627     IntElementIterator intIt(owner, index);
628     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
629     return FloatAttr::get(eltTy, *floatIt);
630   }
631   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
632     auto complexEltTy = complexTy.getElementType();
633     ComplexIntElementIterator complexIntIt(owner, index);
634     if (complexEltTy.isa<IntegerType>()) {
635       auto value = *complexIntIt;
636       auto real = IntegerAttr::get(complexEltTy, value.real());
637       auto imag = IntegerAttr::get(complexEltTy, value.imag());
638       return ArrayAttr::get(complexTy.getContext(),
639                             ArrayRef<Attribute>{real, imag});
640     }
641 
642     ComplexFloatElementIterator complexFloatIt(
643         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
644     auto value = *complexFloatIt;
645     auto real = FloatAttr::get(complexEltTy, value.real());
646     auto imag = FloatAttr::get(complexEltTy, value.imag());
647     return ArrayAttr::get(complexTy.getContext(),
648                           ArrayRef<Attribute>{real, imag});
649   }
650   if (owner.isa<DenseStringElementsAttr>()) {
651     ArrayRef<StringRef> vals = owner.getRawStringData();
652     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
653   }
654   llvm_unreachable("unexpected element type");
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // BoolElementIterator
659 
/// Construct an iterator over the bit-packed boolean data of `attr` positioned
/// at element `dataIndex`. The raw data pointer and the attribute's splat flag
/// are handed to the indexed-iterator base.
DenseElementsAttr::BoolElementIterator::BoolElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
          attr.getRawData().data(), attr.isSplat(), dataIndex) {}
664 
665 bool DenseElementsAttr::BoolElementIterator::operator*() const {
666   return getBit(getData(), getDataIndex());
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // IntElementIterator
671 
/// Construct an iterator over the integer-encoded elements of `attr`
/// positioned at element `dataIndex`. `bitWidth` caches the element type's
/// bit width, which operator* uses to locate and size each element.
DenseElementsAttr::IntElementIterator::IntElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
          attr.getRawData().data(), attr.isSplat(), dataIndex),
      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
677 
678 APInt DenseElementsAttr::IntElementIterator::operator*() const {
679   return readBits(getData(),
680                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
681                   bitWidth);
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // ComplexIntElementIterator
686 
687 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
688     DenseElementsAttr attr, size_t dataIndex)
689     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
690                                       std::complex<APInt>, std::complex<APInt>,
691                                       std::complex<APInt>>(
692           attr.getRawData().data(), attr.isSplat(), dataIndex) {
693   auto complexType = attr.getType().getElementType().cast<ComplexType>();
694   bitWidth = getDenseElementBitWidth(complexType.getElementType());
695 }
696 
697 std::complex<APInt>
698 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
699   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
700   size_t offset = getDataIndex() * storageWidth * 2;
701   return {readBits(getData(), offset, bitWidth),
702           readBits(getData(), offset + storageWidth, bitWidth)};
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // FloatElementIterator
707 
708 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
709     const llvm::fltSemantics &smt, IntElementIterator it)
710     : llvm::mapped_iterator<IntElementIterator,
711                             std::function<APFloat(const APInt &)>>(
712           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
713 
714 //===----------------------------------------------------------------------===//
715 // ComplexFloatElementIterator
716 
717 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
718     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
719     : llvm::mapped_iterator<
720           ComplexIntElementIterator,
721           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
722           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
723             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
724           }) {}
725 
726 //===----------------------------------------------------------------------===//
727 // DenseElementsAttr
728 //===----------------------------------------------------------------------===//
729 
730 /// Method for support type inquiry through isa, cast and dyn_cast.
731 bool DenseElementsAttr::classof(Attribute attr) {
732   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
733 }
734 
735 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
736                                          ArrayRef<Attribute> values) {
737   assert(hasSameElementsOrSplat(type, values));
738 
739   // If the element type is not based on int/float/index, assume it is a string
740   // type.
741   auto eltType = type.getElementType();
742   if (!type.getElementType().isIntOrIndexOrFloat()) {
743     SmallVector<StringRef, 8> stringValues;
744     stringValues.reserve(values.size());
745     for (Attribute attr : values) {
746       assert(attr.isa<StringAttr>() &&
747              "expected string value for non integer/index/float element");
748       stringValues.push_back(attr.cast<StringAttr>().getValue());
749     }
750     return get(type, stringValues);
751   }
752 
753   // Otherwise, get the raw storage width to use for the allocation.
754   size_t bitWidth = getDenseElementBitWidth(eltType);
755   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
756 
757   // Compress the attribute values into a character buffer.
758   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
759                             values.size());
760   APInt intVal;
761   for (unsigned i = 0, e = values.size(); i < e; ++i) {
762     assert(eltType == values[i].getType() &&
763            "expected attribute value to have element type");
764     if (eltType.isa<FloatType>())
765       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
766     else if (eltType.isa<IntegerType, IndexType>())
767       intVal = values[i].cast<IntegerAttr>().getValue();
768     else
769       llvm_unreachable("unexpected element type");
770 
771     assert(intVal.getBitWidth() == bitWidth &&
772            "expected value to have same bitwidth as element type");
773     writeBits(data.data(), i * storageBitWidth, intVal);
774   }
775   return DenseIntOrFPElementsAttr::getRaw(type, data,
776                                           /*isSplat=*/(values.size() == 1));
777 }
778 
779 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
780                                          ArrayRef<bool> values) {
781   assert(hasSameElementsOrSplat(type, values));
782   assert(type.getElementType().isInteger(1));
783 
784   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
785   for (int i = 0, e = values.size(); i != e; ++i)
786     setBit(buff.data(), i, values[i]);
787   return DenseIntOrFPElementsAttr::getRaw(type, buff,
788                                           /*isSplat=*/(values.size() == 1));
789 }
790 
791 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
792                                          ArrayRef<StringRef> values) {
793   assert(!type.getElementType().isIntOrFloat());
794   return DenseStringElementsAttr::get(type, values);
795 }
796 
797 /// Constructs a dense integer elements attribute from an array of APInt
798 /// values. Each APInt value is expected to have the same bitwidth as the
799 /// element type of 'type'.
800 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
801                                          ArrayRef<APInt> values) {
802   assert(type.getElementType().isIntOrIndex());
803   assert(hasSameElementsOrSplat(type, values));
804   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
805   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
806                                           /*isSplat=*/(values.size() == 1));
807 }
808 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
809                                          ArrayRef<std::complex<APInt>> values) {
810   ComplexType complex = type.getElementType().cast<ComplexType>();
811   assert(complex.getElementType().isa<IntegerType>());
812   assert(hasSameElementsOrSplat(type, values));
813   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
814   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
815                           values.size() * 2);
816   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
817                                           /*isSplat=*/(values.size() == 1));
818 }
819 
820 // Constructs a dense float elements attribute from an array of APFloat
821 // values. Each APFloat value is expected to have the same bitwidth as the
822 // element type of 'type'.
823 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
824                                          ArrayRef<APFloat> values) {
825   assert(type.getElementType().isa<FloatType>());
826   assert(hasSameElementsOrSplat(type, values));
827   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
828   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
829                                           /*isSplat=*/(values.size() == 1));
830 }
831 DenseElementsAttr
832 DenseElementsAttr::get(ShapedType type,
833                        ArrayRef<std::complex<APFloat>> values) {
834   ComplexType complex = type.getElementType().cast<ComplexType>();
835   assert(complex.getElementType().isa<FloatType>());
836   assert(hasSameElementsOrSplat(type, values));
837   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
838                            values.size() * 2);
839   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
840   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
841                                           /*isSplat=*/(values.size() == 1));
842 }
843 
844 /// Construct a dense elements attribute from a raw buffer representing the
845 /// data for this attribute. Users should generally not use this methods as
846 /// the expected buffer format may not be a form the user expects.
847 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
848                                                       ArrayRef<char> rawBuffer,
849                                                       bool isSplatBuffer) {
850   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
851 }
852 
853 /// Returns true if the given buffer is a valid raw buffer for the given type.
854 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
855                                          ArrayRef<char> rawBuffer,
856                                          bool &detectedSplat) {
857   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
858   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
859 
860   // Storage width of 1 is special as it is packed by the bit.
861   if (storageWidth == 1) {
862     // Check for a splat, or a buffer equal to the number of elements.
863     if ((detectedSplat = rawBuffer.size() == 1))
864       return true;
865     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
866   }
867   // All other types are 8-bit aligned.
868   if ((detectedSplat = rawBufferWidth == storageWidth))
869     return true;
870   return rawBufferWidth == (storageWidth * type.getNumElements());
871 }
872 
873 /// Check the information for a C++ data type, check if this type is valid for
874 /// the current attribute. This method is used to verify specific type
875 /// invariants that the templatized 'getValues' method cannot.
876 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
877                               bool isSigned) {
878   // Make sure that the data element size is the same as the type element width.
879   if (getDenseElementBitWidth(type) !=
880       static_cast<size_t>(dataEltSize * CHAR_BIT))
881     return false;
882 
883   // Check that the element type is either float or integer or index.
884   if (!isInt)
885     return type.isa<FloatType>();
886   if (type.isIndex())
887     return true;
888 
889   auto intType = type.dyn_cast<IntegerType>();
890   if (!intType)
891     return false;
892 
893   // Make sure signedness semantics is consistent.
894   if (intType.isSignless())
895     return true;
896   return intType.isSigned() ? isSigned : !isSigned;
897 }
898 
899 /// Defaults down the subclass implementation.
900 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
901                                                    ArrayRef<char> data,
902                                                    int64_t dataEltSize,
903                                                    bool isInt, bool isSigned) {
904   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
905                                                  isSigned);
906 }
907 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
908                                                       ArrayRef<char> data,
909                                                       int64_t dataEltSize,
910                                                       bool isInt,
911                                                       bool isSigned) {
912   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
913                                                     isInt, isSigned);
914 }
915 
916 /// A method used to verify specific type invariants that the templatized 'get'
917 /// method cannot.
918 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
919                                           bool isSigned) const {
920   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
921                              isSigned);
922 }
923 
924 /// Check the information for a C++ data type, check if this type is valid for
925 /// the current attribute.
926 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
927                                        bool isSigned) const {
928   return ::isValidIntOrFloat(
929       getType().getElementType().cast<ComplexType>().getElementType(),
930       dataEltSize / 2, isInt, isSigned);
931 }
932 
933 /// Returns true if this attribute corresponds to a splat, i.e. if all element
934 /// values are the same.
935 bool DenseElementsAttr::isSplat() const {
936   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
937 }
938 
939 /// Return the held element values as a range of Attributes.
940 auto DenseElementsAttr::getAttributeValues() const
941     -> llvm::iterator_range<AttributeElementIterator> {
942   return {attr_value_begin(), attr_value_end()};
943 }
944 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
945   return AttributeElementIterator(*this, 0);
946 }
947 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
948   return AttributeElementIterator(*this, getNumElements());
949 }
950 
951 /// Return the held element values as a range of bool. The element type of
952 /// this attribute must be of integer type of bitwidth 1.
953 auto DenseElementsAttr::getBoolValues() const
954     -> llvm::iterator_range<BoolElementIterator> {
955   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
956   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
957   (void)eltType;
958   return {BoolElementIterator(*this, 0),
959           BoolElementIterator(*this, getNumElements())};
960 }
961 
962 /// Return the held element values as a range of APInts. The element type of
963 /// this attribute must be of integer type.
964 auto DenseElementsAttr::getIntValues() const
965     -> llvm::iterator_range<IntElementIterator> {
966   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
967   return {raw_int_begin(), raw_int_end()};
968 }
969 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
970   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
971   return raw_int_begin();
972 }
973 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
974   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
975   return raw_int_end();
976 }
977 auto DenseElementsAttr::getComplexIntValues() const
978     -> llvm::iterator_range<ComplexIntElementIterator> {
979   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
980   (void)eltTy;
981   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
982   return {ComplexIntElementIterator(*this, 0),
983           ComplexIntElementIterator(*this, getNumElements())};
984 }
985 
986 /// Return the held element values as a range of APFloat. The element type of
987 /// this attribute must be of float type.
988 auto DenseElementsAttr::getFloatValues() const
989     -> llvm::iterator_range<FloatElementIterator> {
990   auto elementType = getType().getElementType().cast<FloatType>();
991   const auto &elementSemantics = elementType.getFloatSemantics();
992   return {FloatElementIterator(elementSemantics, raw_int_begin()),
993           FloatElementIterator(elementSemantics, raw_int_end())};
994 }
995 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
996   return getFloatValues().begin();
997 }
998 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
999   return getFloatValues().end();
1000 }
1001 auto DenseElementsAttr::getComplexFloatValues() const
1002     -> llvm::iterator_range<ComplexFloatElementIterator> {
1003   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1004   assert(eltTy.isa<FloatType>() && "expected complex float type");
1005   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1006   return {{semantics, {*this, 0}},
1007           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1008 }
1009 
1010 /// Return the raw storage data held by this attribute.
1011 ArrayRef<char> DenseElementsAttr::getRawData() const {
1012   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1013 }
1014 
1015 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1016   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1017 }
1018 
1019 /// Return a new DenseElementsAttr that has the same data as the current
1020 /// attribute, but has been reshaped to 'newType'. The new type must have the
1021 /// same total number of elements as well as element type.
1022 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1023   ShapedType curType = getType();
1024   if (curType == newType)
1025     return *this;
1026 
1027   assert(newType.getElementType() == curType.getElementType() &&
1028          "expected the same element type");
1029   assert(newType.getNumElements() == curType.getNumElements() &&
1030          "expected the same number of elements");
1031   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1032 }
1033 
1034 /// Return a new DenseElementsAttr that has the same data as the current
1035 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1036 /// type must have the same shape and element types of the same bitwidth as the
1037 /// current type.
1038 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1039   ShapedType curType = getType();
1040   Type curElType = curType.getElementType();
1041   if (curElType == newElType)
1042     return *this;
1043 
1044   assert(getDenseElementBitWidth(newElType) ==
1045              getDenseElementBitWidth(curElType) &&
1046          "expected element types with the same bitwidth");
1047   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1048                                           getRawData(), isSplat());
1049 }
1050 
1051 DenseElementsAttr
1052 DenseElementsAttr::mapValues(Type newElementType,
1053                              function_ref<APInt(const APInt &)> mapping) const {
1054   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1055 }
1056 
1057 DenseElementsAttr DenseElementsAttr::mapValues(
1058     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1059   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // DenseIntOrFPElementsAttr
1064 //===----------------------------------------------------------------------===//
1065 
1066 /// Utility method to write a range of APInt values to a buffer.
1067 template <typename APRangeT>
1068 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1069                                 APRangeT &&values) {
1070   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1071   size_t offset = 0;
1072   for (auto it = values.begin(), e = values.end(); it != e;
1073        ++it, offset += storageWidth) {
1074     assert((*it).getBitWidth() <= storageWidth);
1075     writeBits(data.data(), offset, *it);
1076   }
1077 }
1078 
1079 /// Constructs a dense elements attribute from an array of raw APFloat values.
1080 /// Each APFloat value is expected to have the same bitwidth as the element
1081 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1083                                                    size_t storageWidth,
1084                                                    ArrayRef<APFloat> values,
1085                                                    bool isSplat) {
1086   std::vector<char> data;
1087   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1088   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1089   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1090 }
1091 
1092 /// Constructs a dense elements attribute from an array of raw APInt values.
1093 /// Each APInt value is expected to have the same bitwidth as the element type
1094 /// of 'type'.
1095 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1096                                                    size_t storageWidth,
1097                                                    ArrayRef<APInt> values,
1098                                                    bool isSplat) {
1099   std::vector<char> data;
1100   writeAPIntsToBuffer(storageWidth, data, values);
1101   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1102 }
1103 
1104 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1105                                                    ArrayRef<char> data,
1106                                                    bool isSplat) {
1107   assert((type.isa<RankedTensorType, VectorType>()) &&
1108          "type must be ranked tensor or vector");
1109   assert(type.hasStaticShape() && "type must have static shape");
1110   return Base::get(type.getContext(), type, data, isSplat);
1111 }
1112 
1113 /// Overload of the raw 'get' method that asserts that the given type is of
1114 /// complex type. This method is used to verify type invariants that the
1115 /// templatized 'get' method cannot.
1116 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1117                                                           ArrayRef<char> data,
1118                                                           int64_t dataEltSize,
1119                                                           bool isInt,
1120                                                           bool isSigned) {
1121   assert(::isValidIntOrFloat(
1122       type.getElementType().cast<ComplexType>().getElementType(),
1123       dataEltSize / 2, isInt, isSigned));
1124 
1125   int64_t numElements = data.size() / dataEltSize;
1126   assert(numElements == 1 || numElements == type.getNumElements());
1127   return getRaw(type, data, /*isSplat=*/numElements == 1);
1128 }
1129 
1130 /// Overload of the 'getRaw' method that asserts that the given type is of
1131 /// integer type. This method is used to verify type invariants that the
1132 /// templatized 'get' method cannot.
1133 DenseElementsAttr
1134 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1135                                            int64_t dataEltSize, bool isInt,
1136                                            bool isSigned) {
1137   assert(
1138       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1139 
1140   int64_t numElements = data.size() / dataEltSize;
1141   assert(numElements == 1 || numElements == type.getNumElements());
1142   return getRaw(type, data, /*isSplat=*/numElements == 1);
1143 }
1144 
1145 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1146     const char *inRawData, char *outRawData, size_t elementBitWidth,
1147     size_t numElements) {
1148   using llvm::support::ulittle16_t;
1149   using llvm::support::ulittle32_t;
1150   using llvm::support::ulittle64_t;
1151 
1152   assert(llvm::support::endian::system_endianness() == // NOLINT
1153          llvm::support::endianness::big);              // NOLINT
1154   // NOLINT to avoid warning message about replacing by static_assert()
1155 
1156   // Following std::copy_n always converts endianness on BE machine.
1157   switch (elementBitWidth) {
1158   case 16: {
1159     const ulittle16_t *inRawDataPos =
1160         reinterpret_cast<const ulittle16_t *>(inRawData);
1161     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1162     std::copy_n(inRawDataPos, numElements, outDataPos);
1163     break;
1164   }
1165   case 32: {
1166     const ulittle32_t *inRawDataPos =
1167         reinterpret_cast<const ulittle32_t *>(inRawData);
1168     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1169     std::copy_n(inRawDataPos, numElements, outDataPos);
1170     break;
1171   }
1172   case 64: {
1173     const ulittle64_t *inRawDataPos =
1174         reinterpret_cast<const ulittle64_t *>(inRawData);
1175     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1176     std::copy_n(inRawDataPos, numElements, outDataPos);
1177     break;
1178   }
1179   default: {
1180     size_t nBytes = elementBitWidth / CHAR_BIT;
1181     for (size_t i = 0; i < nBytes; i++)
1182       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1183     break;
1184   }
1185   }
1186 }
1187 
1188 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1189     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1190     ShapedType type) {
1191   size_t numElements = type.getNumElements();
1192   Type elementType = type.getElementType();
1193   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1194     elementType = complexTy.getElementType();
1195     numElements = numElements * 2;
1196   }
1197   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1198   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1199          inRawData.size() <= outRawData.size());
1200   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1201                                   elementBitWidth, numElements);
1202 }
1203 
1204 //===----------------------------------------------------------------------===//
1205 // DenseFPElementsAttr
1206 //===----------------------------------------------------------------------===//
1207 
1208 template <typename Fn, typename Attr>
1209 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1210                                 Type newElementType,
1211                                 llvm::SmallVectorImpl<char> &data) {
1212   size_t bitWidth = getDenseElementBitWidth(newElementType);
1213   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1214 
1215   ShapedType newArrayType;
1216   if (inType.isa<RankedTensorType>())
1217     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1218   else if (inType.isa<UnrankedTensorType>())
1219     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1220   else if (inType.isa<VectorType>())
1221     newArrayType = VectorType::get(inType.getShape(), newElementType);
1222   else
1223     assert(newArrayType && "Unhandled tensor type");
1224 
1225   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1226   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1227 
1228   // Functor used to process a single element value of the attribute.
1229   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1230     auto newInt = mapping(value);
1231     assert(newInt.getBitWidth() == bitWidth);
1232     writeBits(data.data(), index * storageBitWidth, newInt);
1233   };
1234 
1235   // Check for the splat case.
1236   if (attr.isSplat()) {
1237     processElt(*attr.begin(), /*index=*/0);
1238     return newArrayType;
1239   }
1240 
1241   // Otherwise, process all of the element values.
1242   uint64_t elementIdx = 0;
1243   for (auto value : attr)
1244     processElt(value, elementIdx++);
1245   return newArrayType;
1246 }
1247 
1248 DenseElementsAttr DenseFPElementsAttr::mapValues(
1249     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1250   llvm::SmallVector<char, 8> elementData;
1251   auto newArrayType =
1252       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1253 
1254   return getRaw(newArrayType, elementData, isSplat());
1255 }
1256 
1257 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1258 bool DenseFPElementsAttr::classof(Attribute attr) {
1259   return attr.isa<DenseElementsAttr>() &&
1260          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // DenseIntElementsAttr
1265 //===----------------------------------------------------------------------===//
1266 
1267 DenseElementsAttr DenseIntElementsAttr::mapValues(
1268     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1269   llvm::SmallVector<char, 8> elementData;
1270   auto newArrayType =
1271       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1272 
1273   return getRaw(newArrayType, elementData, isSplat());
1274 }
1275 
1276 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1277 bool DenseIntElementsAttr::classof(Attribute attr) {
1278   return attr.isa<DenseElementsAttr>() &&
1279          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1280 }
1281 
1282 //===----------------------------------------------------------------------===//
1283 // OpaqueElementsAttr
1284 //===----------------------------------------------------------------------===//
1285 
1286 /// Return the value at the given index. If index does not refer to a valid
1287 /// element, then a null attribute is returned.
1288 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1289   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1290   return Attribute();
1291 }
1292 
1293 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1294   Dialect *dialect = getDialect().getDialect();
1295   if (!dialect)
1296     return true;
1297   auto *interface =
1298       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1299   if (!interface)
1300     return true;
1301   return failed(interface->decode(*this, result));
1302 }
1303 
1304 LogicalResult
1305 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1306                            Identifier dialect, StringRef value,
1307                            ShapedType type) {
1308   if (!Dialect::isValidNamespace(dialect.strref()))
1309     return emitError() << "invalid dialect namespace '" << dialect << "'";
1310   return success();
1311 }
1312 
1313 //===----------------------------------------------------------------------===//
1314 // SparseElementsAttr
1315 //===----------------------------------------------------------------------===//
1316 
1317 /// Return the value of the element at the given index.
1318 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1319   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1320   auto type = getType();
1321 
1322   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1323   // as a 1-D index array.
1324   auto sparseIndices = getIndices();
1325   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1326 
1327   // Check to see if the indices are a splat.
1328   if (sparseIndices.isSplat()) {
1329     // If the index is also not a splat of the index value, we know that the
1330     // value is zero.
1331     auto splatIndex = *sparseIndexValues.begin();
1332     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1333       return getZeroAttr();
1334 
1335     // If the indices are a splat, we also expect the values to be a splat.
1336     assert(getValues().isSplat() && "expected splat values");
1337     return getValues().getSplatValue();
1338   }
1339 
1340   // Build a mapping between known indices and the offset of the stored element.
1341   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1342   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1343   size_t rank = type.getRank();
1344   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1345     mappedIndices.try_emplace(
1346         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1347 
1348   // Look for the provided index key within the mapped indices. If the provided
1349   // index is not found, then return a zero attribute.
1350   auto it = mappedIndices.find(index);
1351   if (it == mappedIndices.end())
1352     return getZeroAttr();
1353 
1354   // Otherwise, return the held sparse value element.
1355   return getValues().getValue(it->second);
1356 }
1357 
1358 /// Get a zero APFloat for the given sparse attribute.
1359 APFloat SparseElementsAttr::getZeroAPFloat() const {
1360   auto eltType = getType().getElementType().cast<FloatType>();
1361   return APFloat(eltType.getFloatSemantics());
1362 }
1363 
1364 /// Get a zero APInt for the given sparse attribute.
1365 APInt SparseElementsAttr::getZeroAPInt() const {
1366   auto eltType = getType().getElementType().cast<IntegerType>();
1367   return APInt::getNullValue(eltType.getWidth());
1368 }
1369 
1370 /// Get a zero attribute for the given attribute type.
1371 Attribute SparseElementsAttr::getZeroAttr() const {
1372   auto eltType = getType().getElementType();
1373 
1374   // Handle floating point elements.
1375   if (eltType.isa<FloatType>())
1376     return FloatAttr::get(eltType, 0);
1377 
1378   // Otherwise, this is an integer.
1379   // TODO: Handle StringAttr here.
1380   return IntegerAttr::get(eltType, 0);
1381 }
1382 
1383 /// Flatten, and return, all of the sparse indices in this attribute in
1384 /// row-major order.
1385 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1386   std::vector<ptrdiff_t> flatSparseIndices;
1387 
1388   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1389   // as a 1-D index array.
1390   auto sparseIndices = getIndices();
1391   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1392   if (sparseIndices.isSplat()) {
1393     SmallVector<uint64_t, 8> indices(getType().getRank(),
1394                                      *sparseIndexValues.begin());
1395     flatSparseIndices.push_back(getFlattenedIndex(indices));
1396     return flatSparseIndices;
1397   }
1398 
1399   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1400   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1401   size_t rank = getType().getRank();
1402   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1403     flatSparseIndices.push_back(getFlattenedIndex(
1404         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1405   return flatSparseIndices;
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // TypeAttr
1410 //===----------------------------------------------------------------------===//
1411 
1412 void TypeAttr::walkImmediateSubElements(
1413     function_ref<void(Attribute)> walkAttrsFn,
1414     function_ref<void(Type)> walkTypesFn) const {
1415   walkTypesFn(getValue());
1416 }
1417