1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
19 #include "llvm/ADT/APSInt.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
37 void BuiltinDialect::registerAttributes() {
38   addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
39                 DenseStringElementsAttr, DictionaryAttr, FloatAttr,
40                 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
41                 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
42                 UnitAttr>();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
57     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
58   std::vector<Attribute> vector = getValue().vec();
59   for (auto &it : replacements) {
60     vector[it.first] = it.second;
61   }
62   return get(getContext(), vector);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // DictionaryAttr
67 //===----------------------------------------------------------------------===//
68 
/// Helper that orders named attributes (via NamedAttribute's operator<),
/// either in place or by copying from a source array into a destination.
///
/// If `inPlace` is true, `value` and `storage` are expected to view the same
/// elements (callers pass the same array for both): the order is checked on
/// `value` and any reordering is done on `storage`. If `inPlace` is false,
/// `value` is the source and `storage` the destination; `storage` always ends
/// up holding a sorted copy of `value`, even when `value` was already sorted.
///
/// Returns true if `value` was NOT already in sorted order (so the sorted
/// result differs in order from the input), false otherwise. Note the return
/// value is the opposite of "was sorted". DictionaryAttr::get relies on this
/// to decide whether it must switch to `storage`.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    if (!inPlace)
      storage.clear();
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    // A strict `<` is used here, so two entries that compare equal are
    // treated as unsorted: they are swapped and `true` is returned.
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    // Copy first so that `storage` holds the elements in both outcomes.
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already. This reads `value`, which in the
    // in-place case is the same array that may be sorted below.
    bool isSorted = llvm::is_sorted(value);
    // If not, do a general sort.
    if (!isSorted)
      llvm::array_pod_sort(storage.begin(), storage.end());
    return !isSorted;
  }
  // Reached only for 0 or 1 elements, which are trivially sorted.
  return false;
}
112 
113 /// Returns an entry with a duplicate name from the given sorted array of named
114 /// attributes. Returns llvm::None if all elements have unique names.
115 static Optional<NamedAttribute>
116 findDuplicateElement(ArrayRef<NamedAttribute> value) {
117   const Optional<NamedAttribute> none{llvm::None};
118   if (value.size() < 2)
119     return none;
120 
121   if (value.size() == 2)
122     return value[0].getName() == value[1].getName() ? value[0] : none;
123 
124   const auto *it = std::adjacent_find(value.begin(), value.end(),
125                                       [](NamedAttribute l, NamedAttribute r) {
126                                         return l.getName() == r.getName();
127                                       });
128   return it != value.end() ? *it : none;
129 }
130 
131 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
132                           SmallVectorImpl<NamedAttribute> &storage) {
133   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
134   assert(!findDuplicateElement(storage) &&
135          "DictionaryAttr element names must be unique");
136   return isSorted;
137 }
138 
139 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
140   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
141   assert(!findDuplicateElement(array) &&
142          "DictionaryAttr element names must be unique");
143   return isSorted;
144 }
145 
146 Optional<NamedAttribute>
147 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
148                               bool isSorted) {
149   if (!isSorted)
150     dictionaryAttrSort</*inPlace=*/true>(array, array);
151   return findDuplicateElement(array);
152 }
153 
154 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
155                                    ArrayRef<NamedAttribute> value) {
156   if (value.empty())
157     return DictionaryAttr::getEmpty(context);
158 
159   // We need to sort the element list to canonicalize it.
160   SmallVector<NamedAttribute, 8> storage;
161   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
162     value = storage;
163   assert(!findDuplicateElement(value) &&
164          "DictionaryAttr element names must be unique");
165   return Base::get(context, value);
166 }
167 /// Construct a dictionary with an array of values that is known to already be
168 /// sorted by name and uniqued.
169 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
170                                              ArrayRef<NamedAttribute> value) {
171   if (value.empty())
172     return DictionaryAttr::getEmpty(context);
173   // Ensure that the attribute elements are unique and sorted.
174   assert(llvm::is_sorted(
175              value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
176          "expected attribute values to be sorted");
177   assert(!findDuplicateElement(value) &&
178          "DictionaryAttr element names must be unique");
179   return Base::get(context, value);
180 }
181 
182 /// Return the specified attribute if present, null otherwise.
183 Attribute DictionaryAttr::get(StringRef name) const {
184   auto it = impl::findAttrSorted(begin(), end(), name);
185   return it.second ? it.first->getValue() : Attribute();
186 }
187 Attribute DictionaryAttr::get(StringAttr name) const {
188   auto it = impl::findAttrSorted(begin(), end(), name);
189   return it.second ? it.first->getValue() : Attribute();
190 }
191 
192 /// Return the specified named attribute if present, None otherwise.
193 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
194   auto it = impl::findAttrSorted(begin(), end(), name);
195   return it.second ? *it.first : Optional<NamedAttribute>();
196 }
197 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
198   auto it = impl::findAttrSorted(begin(), end(), name);
199   return it.second ? *it.first : Optional<NamedAttribute>();
200 }
201 
202 /// Return whether the specified attribute is present.
203 bool DictionaryAttr::contains(StringRef name) const {
204   return impl::findAttrSorted(begin(), end(), name).second;
205 }
206 bool DictionaryAttr::contains(StringAttr name) const {
207   return impl::findAttrSorted(begin(), end(), name).second;
208 }
209 
210 DictionaryAttr::iterator DictionaryAttr::begin() const {
211   return getValue().begin();
212 }
213 DictionaryAttr::iterator DictionaryAttr::end() const {
214   return getValue().end();
215 }
216 size_t DictionaryAttr::size() const { return getValue().size(); }
217 
218 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
219   return Base::get(context, ArrayRef<NamedAttribute>());
220 }
221 
222 void DictionaryAttr::walkImmediateSubElements(
223     function_ref<void(Attribute)> walkAttrsFn,
224     function_ref<void(Type)> walkTypesFn) const {
225   for (const NamedAttribute &attr : getValue())
226     walkAttrsFn(attr.getValue());
227 }
228 
229 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
230     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
231   std::vector<NamedAttribute> vec = getValue().vec();
232   for (auto &it : replacements)
233     vec[it.first].setValue(it.second);
234 
235   // The above only modifies the mapped value, but not the key, and therefore
236   // not the order of the elements. It remains sorted
237   return getWithSorted(getContext(), vec);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // StringAttr
242 //===----------------------------------------------------------------------===//
243 
244 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
245   return Base::get(context, "", NoneType::get(context));
246 }
247 
248 /// Twine support for StringAttr.
249 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
250   // Fast-path empty twine.
251   if (twine.isTriviallyEmpty())
252     return get(context);
253   SmallVector<char, 32> tempStr;
254   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
255 }
256 
257 /// Twine support for StringAttr.
258 StringAttr StringAttr::get(const Twine &twine, Type type) {
259   SmallVector<char, 32> tempStr;
260   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
261 }
262 
263 StringRef StringAttr::getValue() const { return getImpl()->value; }
264 
265 Dialect *StringAttr::getReferencedDialect() const {
266   return getImpl()->referencedDialect;
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // FloatAttr
271 //===----------------------------------------------------------------------===//
272 
273 double FloatAttr::getValueAsDouble() const {
274   return getValueAsDouble(getValue());
275 }
276 double FloatAttr::getValueAsDouble(APFloat value) {
277   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
278     bool losesInfo = false;
279     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
280                   &losesInfo);
281   }
282   return value.convertToDouble();
283 }
284 
285 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
286                                 Type type, APFloat value) {
287   // Verify that the type is correct.
288   if (!type.isa<FloatType>())
289     return emitError() << "expected floating point type";
290 
291   // Verify that the type semantics match that of the value.
292   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
293     return emitError()
294            << "FloatAttr type doesn't match the type implied by its value";
295   }
296   return success();
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // SymbolRefAttr
301 //===----------------------------------------------------------------------===//
302 
303 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
304                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
305   return get(StringAttr::get(ctx, value), nestedRefs);
306 }
307 
308 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
309   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
310 }
311 
312 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
313   return get(value, {}).cast<FlatSymbolRefAttr>();
314 }
315 
316 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
317   auto symName =
318       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
319   assert(symName && "value does not have a valid symbol name");
320   return SymbolRefAttr::get(symName);
321 }
322 
323 StringAttr SymbolRefAttr::getLeafReference() const {
324   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
325   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // IntegerAttr
330 //===----------------------------------------------------------------------===//
331 
332 int64_t IntegerAttr::getInt() const {
333   assert((getType().isIndex() || getType().isSignlessInteger()) &&
334          "must be signless integer");
335   return getValue().getSExtValue();
336 }
337 
338 int64_t IntegerAttr::getSInt() const {
339   assert(getType().isSignedInteger() && "must be signed integer");
340   return getValue().getSExtValue();
341 }
342 
343 uint64_t IntegerAttr::getUInt() const {
344   assert(getType().isUnsignedInteger() && "must be unsigned integer");
345   return getValue().getZExtValue();
346 }
347 
348 /// Return the value as an APSInt which carries the signed from the type of
349 /// the attribute.  This traps on signless integers types!
350 APSInt IntegerAttr::getAPSInt() const {
351   assert(!getType().isSignlessInteger() &&
352          "Signless integers don't carry a sign for APSInt");
353   return APSInt(getValue(), getType().isUnsignedInteger());
354 }
355 
356 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
357                                   Type type, APInt value) {
358   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
359     if (integerType.getWidth() != value.getBitWidth())
360       return emitError() << "integer type bit width (" << integerType.getWidth()
361                          << ") doesn't match value bit width ("
362                          << value.getBitWidth() << ")";
363     return success();
364   }
365   if (type.isa<IndexType>())
366     return success();
367   return emitError() << "expected integer or index type";
368 }
369 
370 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
371   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
372   return attr.cast<BoolAttr>();
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // BoolAttr
377 //===----------------------------------------------------------------------===//
378 
379 bool BoolAttr::getValue() const {
380   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
381   return storage->value.getBoolValue();
382 }
383 
384 bool BoolAttr::classof(Attribute attr) {
385   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
386   return intAttr && intAttr.getType().isSignlessInteger(1);
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // OpaqueAttr
391 //===----------------------------------------------------------------------===//
392 
393 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
394                                  StringAttr dialect, StringRef attrData,
395                                  Type type) {
396   if (!Dialect::isValidNamespace(dialect.strref()))
397     return emitError() << "invalid dialect namespace '" << dialect << "'";
398 
399   // Check that the dialect is actually registered.
400   MLIRContext *context = dialect.getContext();
401   if (!context->allowsUnregisteredDialects() &&
402       !context->getLoadedDialect(dialect.strref())) {
403     return emitError()
404            << "#" << dialect << "<\"" << attrData << "\"> : " << type
405            << " attribute created with unregistered dialect. If this is "
406               "intended, please call allowUnregisteredDialects() on the "
407               "MLIRContext, or use -allow-unregistered-dialect with "
408               "the MLIR opt tool used";
409   }
410 
411   return success();
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // DenseElementsAttr Utilities
416 //===----------------------------------------------------------------------===//
417 
/// Number of bits one element of width `origWidth` occupies in a dense buffer.
/// 1-bit elements (booleans) stay bit-packed; every other width is rounded up
/// to the next multiple of 8 so those elements are byte aligned.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return 1;
  return (origWidth + 7) / 8 * 8;
}
423 static size_t getDenseElementStorageWidth(Type elementType) {
424   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
425 }
426 
/// Set (`value` == true) or clear (`value` == false) bit `bitPos` of the
/// packed bit array `rawData`. Bit 0 is the least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? static_cast<char>(byte | mask)
               : static_cast<char>(byte & ~mask);
}
434 
/// Read bit `bitPos` of the packed bit array `rawData`. Bit 0 is the least
/// significant bit of byte 0.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
439 
/// Copy the first `numBytes` bytes of data held by `value` (an APInt) into the
/// char array `result`. Only valid on big-endian hosts, which is asserted.
///
/// APInt stores its bits in words of APInt::APINT_WORD_SIZE bytes. All words
/// but the last are copied verbatim. The last word may be only partially
/// used, and on a big-endian host its significant bytes sit at the END of the
/// word (see the examples below), so a straight prefix copy would pick up the
/// wrong bytes. Instead the last word is routed through a temporary buffer
/// using DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine.
///
/// \param value    integer to copy; must span at least `numBytes` bytes.
/// \param numBytes number of bytes written to `result`.
/// \param result   destination buffer of at least `numBytes` bytes.
///
/// NOTE(review): `numBytes - lastWordPos` below assumes `numBytes` extends
/// into the last word (true for writeBits, which passes ceil(bitWidth / 8)).
/// A smaller `numBytes` would underflow -- confirm no other callers exist.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Despite its name, `numFilledWords` is a BYTE count: the size of every
  // word except the last.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word in both `value`'s raw
  // data and `result`.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`. Only the remaining `numBytes - lastWordPos`
  // bytes are converted, so `result` is not written past `numBytes`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
469 
/// Copy `numBytes` bytes from the char array `inArray` into the storage of
/// `result` (an APInt). Only valid on big-endian hosts, which is asserted.
/// This is the reverse of copyAPIntToArrayForBEmachine.
///
/// All of `result`'s words but the last are filled verbatim from `inArray`.
/// The trailing, possibly partial group of bytes is routed through a
/// temporary buffer using
/// DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine, so that it lands
/// at the end of the last word where APInt expects it on a big-endian host
/// (see the examples below).
///
/// \param inArray  source buffer of at least `numBytes` bytes.
/// \param numBytes number of bytes to read.
/// \param result   destination APInt; must span at least `numBytes` bytes.
///
/// NOTE(review): like its counterpart, this assumes `numBytes` extends into
/// the last word of `result` (true for readBits); otherwise
/// `numBytes - lastWordPos` underflows.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Despite its name, `numFilledWords` is a BYTE count: the size of every
  // word except the last. The const_cast writes into APInt's raw storage.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  // Only the `numBytes - lastWordPos` bytes actually present in `inArray` are
  // read.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
505 
506 /// Writes value to the bit position `bitPos` in array `rawData`.
507 static void writeBits(char *rawData, size_t bitPos, APInt value) {
508   size_t bitWidth = value.getBitWidth();
509 
510   // If the bitwidth is 1 we just toggle the specific bit.
511   if (bitWidth == 1)
512     return setBit(rawData, bitPos, value.isOneValue());
513 
514   // Otherwise, the bit position is guaranteed to be byte aligned.
515   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
516   if (llvm::support::endian::system_endianness() ==
517       llvm::support::endianness::big) {
518     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
519     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
520     // work correctly in BE format.
521     // ex. `value` (2 words including 10 bytes)
522     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
523     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
524                                  rawData + (bitPos / CHAR_BIT));
525   } else {
526     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
527                 llvm::divideCeil(bitWidth, CHAR_BIT),
528                 rawData + (bitPos / CHAR_BIT));
529   }
530 }
531 
532 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
533 /// `rawData`.
534 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
535   // Handle a boolean bit position.
536   if (bitWidth == 1)
537     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
538 
539   // Otherwise, the bit position must be 8-bit aligned.
540   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
541   APInt result(bitWidth, 0);
542   if (llvm::support::endian::system_endianness() ==
543       llvm::support::endianness::big) {
544     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
545     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
546     // work correctly in BE format.
547     // ex. `result` (2 words including 10 bytes)
548     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
549     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
550                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
551   } else {
552     std::copy_n(rawData + (bitPos / CHAR_BIT),
553                 llvm::divideCeil(bitWidth, CHAR_BIT),
554                 const_cast<char *>(
555                     reinterpret_cast<const char *>(result.getRawData())));
556   }
557   return result;
558 }
559 
560 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
561 /// the same element count as 'type'.
562 template <typename Values>
563 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
564   return (values.size() == 1) ||
565          (type.getNumElements() == static_cast<int64_t>(values.size()));
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // DenseElementsAttr Iterators
570 //===----------------------------------------------------------------------===//
571 
572 //===----------------------------------------------------------------------===//
573 // AttributeElementIterator
574 
575 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
576     DenseElementsAttr attr, size_t index)
577     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
578                                       Attribute, Attribute, Attribute>(
579           attr.getAsOpaquePointer(), index) {}
580 
581 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
582   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
583   Type eltTy = owner.getElementType();
584   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
585     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
586   if (eltTy.isa<IndexType>())
587     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
588   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
589     IntElementIterator intIt(owner, index);
590     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
591     return FloatAttr::get(eltTy, *floatIt);
592   }
593   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
594     auto complexEltTy = complexTy.getElementType();
595     ComplexIntElementIterator complexIntIt(owner, index);
596     if (complexEltTy.isa<IntegerType>()) {
597       auto value = *complexIntIt;
598       auto real = IntegerAttr::get(complexEltTy, value.real());
599       auto imag = IntegerAttr::get(complexEltTy, value.imag());
600       return ArrayAttr::get(complexTy.getContext(),
601                             ArrayRef<Attribute>{real, imag});
602     }
603 
604     ComplexFloatElementIterator complexFloatIt(
605         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
606     auto value = *complexFloatIt;
607     auto real = FloatAttr::get(complexEltTy, value.real());
608     auto imag = FloatAttr::get(complexEltTy, value.imag());
609     return ArrayAttr::get(complexTy.getContext(),
610                           ArrayRef<Attribute>{real, imag});
611   }
612   if (owner.isa<DenseStringElementsAttr>()) {
613     ArrayRef<StringRef> vals = owner.getRawStringData();
614     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
615   }
616   llvm_unreachable("unexpected element type");
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // BoolElementIterator
621 
622 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
623     DenseElementsAttr attr, size_t dataIndex)
624     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
625           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
626 
627 bool DenseElementsAttr::BoolElementIterator::operator*() const {
628   return getBit(getData(), getDataIndex());
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // IntElementIterator
633 
634 DenseElementsAttr::IntElementIterator::IntElementIterator(
635     DenseElementsAttr attr, size_t dataIndex)
636     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
637           attr.getRawData().data(), attr.isSplat(), dataIndex),
638       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
639 
640 APInt DenseElementsAttr::IntElementIterator::operator*() const {
641   return readBits(getData(),
642                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
643                   bitWidth);
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // ComplexIntElementIterator
648 
649 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
650     DenseElementsAttr attr, size_t dataIndex)
651     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
652                                       std::complex<APInt>, std::complex<APInt>,
653                                       std::complex<APInt>>(
654           attr.getRawData().data(), attr.isSplat(), dataIndex) {
655   auto complexType = attr.getElementType().cast<ComplexType>();
656   bitWidth = getDenseElementBitWidth(complexType.getElementType());
657 }
658 
659 std::complex<APInt>
660 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
661   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
662   size_t offset = getDataIndex() * storageWidth * 2;
663   return {readBits(getData(), offset, bitWidth),
664           readBits(getData(), offset + storageWidth, bitWidth)};
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // DenseElementsAttr
669 //===----------------------------------------------------------------------===//
670 
671 /// Method for support type inquiry through isa, cast and dyn_cast.
672 bool DenseElementsAttr::classof(Attribute attr) {
673   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
674 }
675 
676 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
677                                          ArrayRef<Attribute> values) {
678   assert(hasSameElementsOrSplat(type, values));
679 
680   // If the element type is not based on int/float/index, assume it is a string
681   // type.
682   auto eltType = type.getElementType();
683   if (!type.getElementType().isIntOrIndexOrFloat()) {
684     SmallVector<StringRef, 8> stringValues;
685     stringValues.reserve(values.size());
686     for (Attribute attr : values) {
687       assert(attr.isa<StringAttr>() &&
688              "expected string value for non integer/index/float element");
689       stringValues.push_back(attr.cast<StringAttr>().getValue());
690     }
691     return get(type, stringValues);
692   }
693 
694   // Otherwise, get the raw storage width to use for the allocation.
695   size_t bitWidth = getDenseElementBitWidth(eltType);
696   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
697 
698   // Compress the attribute values into a character buffer.
699   SmallVector<char, 8> data(
700       llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
701   APInt intVal;
702   for (unsigned i = 0, e = values.size(); i < e; ++i) {
703     assert(eltType == values[i].getType() &&
704            "expected attribute value to have element type");
705     if (eltType.isa<FloatType>())
706       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
707     else if (eltType.isa<IntegerType, IndexType>())
708       intVal = values[i].cast<IntegerAttr>().getValue();
709     else
710       llvm_unreachable("unexpected element type");
711 
712     assert(intVal.getBitWidth() == bitWidth &&
713            "expected value to have same bitwidth as element type");
714     writeBits(data.data(), i * storageBitWidth, intVal);
715   }
716 
717   // Handle the special encoding of splat of bool.
718   if (values.size() == 1 && values[0].getType().isInteger(1))
719     data[0] = data[0] ? -1 : 0;
720 
721   return DenseIntOrFPElementsAttr::getRaw(type, data);
722 }
723 
724 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
725                                          ArrayRef<bool> values) {
726   assert(hasSameElementsOrSplat(type, values));
727   assert(type.getElementType().isInteger(1));
728 
729   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
730 
731   if (!values.empty()) {
732     bool isSplat = true;
733     bool firstValue = values[0];
734     for (int i = 0, e = values.size(); i != e; ++i) {
735       isSplat &= values[i] == firstValue;
736       setBit(buff.data(), i, values[i]);
737     }
738 
739     // Splat of bool is encoded as a byte with all-ones in it.
740     if (isSplat) {
741       buff.resize(1);
742       buff[0] = values[0] ? -1 : 0;
743     }
744   }
745 
746   return DenseIntOrFPElementsAttr::getRaw(type, buff);
747 }
748 
749 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
750                                          ArrayRef<StringRef> values) {
751   assert(!type.getElementType().isIntOrFloat());
752   return DenseStringElementsAttr::get(type, values);
753 }
754 
755 /// Constructs a dense integer elements attribute from an array of APInt
756 /// values. Each APInt value is expected to have the same bitwidth as the
757 /// element type of 'type'.
758 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
759                                          ArrayRef<APInt> values) {
760   assert(type.getElementType().isIntOrIndex());
761   assert(hasSameElementsOrSplat(type, values));
762   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
763   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
764 }
765 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
766                                          ArrayRef<std::complex<APInt>> values) {
767   ComplexType complex = type.getElementType().cast<ComplexType>();
768   assert(complex.getElementType().isa<IntegerType>());
769   assert(hasSameElementsOrSplat(type, values));
770   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
771   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
772                           values.size() * 2);
773   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
774 }
775 
776 // Constructs a dense float elements attribute from an array of APFloat
777 // values. Each APFloat value is expected to have the same bitwidth as the
778 // element type of 'type'.
779 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
780                                          ArrayRef<APFloat> values) {
781   assert(type.getElementType().isa<FloatType>());
782   assert(hasSameElementsOrSplat(type, values));
783   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
784   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
785 }
786 DenseElementsAttr
787 DenseElementsAttr::get(ShapedType type,
788                        ArrayRef<std::complex<APFloat>> values) {
789   ComplexType complex = type.getElementType().cast<ComplexType>();
790   assert(complex.getElementType().isa<FloatType>());
791   assert(hasSameElementsOrSplat(type, values));
792   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
793                            values.size() * 2);
794   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
795   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
796 }
797 
798 /// Construct a dense elements attribute from a raw buffer representing the
799 /// data for this attribute. Users should generally not use this methods as
800 /// the expected buffer format may not be a form the user expects.
801 DenseElementsAttr
802 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
803   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
804 }
805 
806 /// Returns true if the given buffer is a valid raw buffer for the given type.
807 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
808                                          ArrayRef<char> rawBuffer,
809                                          bool &detectedSplat) {
810   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
811   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
812   int64_t numElements = type.getNumElements();
813 
814   // The initializer is always a splat if the result type has a single element.
815   detectedSplat = numElements == 1;
816 
817   // Storage width of 1 is special as it is packed by the bit.
818   if (storageWidth == 1) {
819     // Check for a splat, or a buffer equal to the number of elements which
820     // consists of either all 0's or all 1's.
821     if (rawBuffer.size() == 1) {
822       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
823       if (rawByte == 0 || rawByte == 0xff) {
824         detectedSplat = true;
825         return true;
826       }
827     }
828 
829     // This is a valid non-splat buffer if it has the right size.
830     return rawBufferWidth == llvm::alignTo<8>(numElements);
831   }
832 
833   // All other types are 8-bit aligned, so we can just check the buffer width
834   // to know if only a single initializer element was passed in.
835   if (rawBufferWidth == storageWidth) {
836     detectedSplat = true;
837     return true;
838   }
839 
840   // The raw buffer is valid if it has the right size.
841   return rawBufferWidth == storageWidth * numElements;
842 }
843 
844 /// Check the information for a C++ data type, check if this type is valid for
845 /// the current attribute. This method is used to verify specific type
846 /// invariants that the templatized 'getValues' method cannot.
847 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
848                               bool isSigned) {
849   // Make sure that the data element size is the same as the type element width.
850   if (getDenseElementBitWidth(type) !=
851       static_cast<size_t>(dataEltSize * CHAR_BIT))
852     return false;
853 
854   // Check that the element type is either float or integer or index.
855   if (!isInt)
856     return type.isa<FloatType>();
857   if (type.isIndex())
858     return true;
859 
860   auto intType = type.dyn_cast<IntegerType>();
861   if (!intType)
862     return false;
863 
864   // Make sure signedness semantics is consistent.
865   if (intType.isSignless())
866     return true;
867   return intType.isSigned() ? isSigned : !isSigned;
868 }
869 
870 /// Defaults down the subclass implementation.
871 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
872                                                    ArrayRef<char> data,
873                                                    int64_t dataEltSize,
874                                                    bool isInt, bool isSigned) {
875   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
876                                                  isSigned);
877 }
878 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
879                                                       ArrayRef<char> data,
880                                                       int64_t dataEltSize,
881                                                       bool isInt,
882                                                       bool isSigned) {
883   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
884                                                     isInt, isSigned);
885 }
886 
887 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
888                                           bool isSigned) const {
889   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
890 }
891 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
892                                        bool isSigned) const {
893   return ::isValidIntOrFloat(
894       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
895       isInt, isSigned);
896 }
897 
898 /// Returns true if this attribute corresponds to a splat, i.e. if all element
899 /// values are the same.
900 bool DenseElementsAttr::isSplat() const {
901   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
902 }
903 
904 /// Return if the given complex type has an integer element type.
905 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
906   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
907 }
908 
909 auto DenseElementsAttr::getComplexIntValues() const
910     -> iterator_range_impl<ComplexIntElementIterator> {
911   assert(isComplexOfIntType(getElementType()) &&
912          "expected complex integral type");
913   return {getType(), ComplexIntElementIterator(*this, 0),
914           ComplexIntElementIterator(*this, getNumElements())};
915 }
916 auto DenseElementsAttr::complex_value_begin() const
917     -> ComplexIntElementIterator {
918   assert(isComplexOfIntType(getElementType()) &&
919          "expected complex integral type");
920   return ComplexIntElementIterator(*this, 0);
921 }
922 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
923   assert(isComplexOfIntType(getElementType()) &&
924          "expected complex integral type");
925   return ComplexIntElementIterator(*this, getNumElements());
926 }
927 
928 /// Return the held element values as a range of APFloat. The element type of
929 /// this attribute must be of float type.
930 auto DenseElementsAttr::getFloatValues() const
931     -> iterator_range_impl<FloatElementIterator> {
932   auto elementType = getElementType().cast<FloatType>();
933   const auto &elementSemantics = elementType.getFloatSemantics();
934   return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
935           FloatElementIterator(elementSemantics, raw_int_end())};
936 }
937 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
938   auto elementType = getElementType().cast<FloatType>();
939   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
940 }
941 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
942   auto elementType = getElementType().cast<FloatType>();
943   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
944 }
945 
946 auto DenseElementsAttr::getComplexFloatValues() const
947     -> iterator_range_impl<ComplexFloatElementIterator> {
948   Type eltTy = getElementType().cast<ComplexType>().getElementType();
949   assert(eltTy.isa<FloatType>() && "expected complex float type");
950   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
951   return {getType(),
952           {semantics, {*this, 0}},
953           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
954 }
955 auto DenseElementsAttr::complex_float_value_begin() const
956     -> ComplexFloatElementIterator {
957   Type eltTy = getElementType().cast<ComplexType>().getElementType();
958   assert(eltTy.isa<FloatType>() && "expected complex float type");
959   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
960 }
961 auto DenseElementsAttr::complex_float_value_end() const
962     -> ComplexFloatElementIterator {
963   Type eltTy = getElementType().cast<ComplexType>().getElementType();
964   assert(eltTy.isa<FloatType>() && "expected complex float type");
965   return {eltTy.cast<FloatType>().getFloatSemantics(),
966           {*this, static_cast<size_t>(getNumElements())}};
967 }
968 
969 /// Return the raw storage data held by this attribute.
970 ArrayRef<char> DenseElementsAttr::getRawData() const {
971   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
972 }
973 
974 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
975   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
976 }
977 
978 /// Return a new DenseElementsAttr that has the same data as the current
979 /// attribute, but has been reshaped to 'newType'. The new type must have the
980 /// same total number of elements as well as element type.
981 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
982   ShapedType curType = getType();
983   if (curType == newType)
984     return *this;
985 
986   assert(newType.getElementType() == curType.getElementType() &&
987          "expected the same element type");
988   assert(newType.getNumElements() == curType.getNumElements() &&
989          "expected the same number of elements");
990   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
991 }
992 
993 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
994   assert(isSplat() && "expected a splat type");
995 
996   ShapedType curType = getType();
997   if (curType == newType)
998     return *this;
999 
1000   assert(newType.getElementType() == curType.getElementType() &&
1001          "expected the same element type");
1002   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1003 }
1004 
1005 /// Return a new DenseElementsAttr that has the same data as the current
1006 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1007 /// type must have the same shape and element types of the same bitwidth as the
1008 /// current type.
1009 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1010   ShapedType curType = getType();
1011   Type curElType = curType.getElementType();
1012   if (curElType == newElType)
1013     return *this;
1014 
1015   assert(getDenseElementBitWidth(newElType) ==
1016              getDenseElementBitWidth(curElType) &&
1017          "expected element types with the same bitwidth");
1018   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1019                                           getRawData());
1020 }
1021 
1022 DenseElementsAttr
1023 DenseElementsAttr::mapValues(Type newElementType,
1024                              function_ref<APInt(const APInt &)> mapping) const {
1025   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1026 }
1027 
1028 DenseElementsAttr DenseElementsAttr::mapValues(
1029     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1030   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1031 }
1032 
1033 ShapedType DenseElementsAttr::getType() const {
1034   return Attribute::getType().cast<ShapedType>();
1035 }
1036 
1037 Type DenseElementsAttr::getElementType() const {
1038   return getType().getElementType();
1039 }
1040 
1041 int64_t DenseElementsAttr::getNumElements() const {
1042   return getType().getNumElements();
1043 }
1044 
1045 //===----------------------------------------------------------------------===//
1046 // DenseIntOrFPElementsAttr
1047 //===----------------------------------------------------------------------===//
1048 
1049 /// Utility method to write a range of APInt values to a buffer.
1050 template <typename APRangeT>
1051 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1052                                 APRangeT &&values) {
1053   size_t numValues = llvm::size(values);
1054   data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1055   size_t offset = 0;
1056   for (auto it = values.begin(), e = values.end(); it != e;
1057        ++it, offset += storageWidth) {
1058     assert((*it).getBitWidth() <= storageWidth);
1059     writeBits(data.data(), offset, *it);
1060   }
1061 
1062   // Handle the special encoding of splat of a boolean.
1063   if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1064     data[0] = data[0] ? -1 : 0;
1065 }
1066 
1067 /// Constructs a dense elements attribute from an array of raw APFloat values.
1068 /// Each APFloat value is expected to have the same bitwidth as the element
1069 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1071                                                    size_t storageWidth,
1072                                                    ArrayRef<APFloat> values) {
1073   std::vector<char> data;
1074   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1075   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1076   return DenseIntOrFPElementsAttr::getRaw(type, data);
1077 }
1078 
1079 /// Constructs a dense elements attribute from an array of raw APInt values.
1080 /// Each APInt value is expected to have the same bitwidth as the element type
1081 /// of 'type'.
1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1083                                                    size_t storageWidth,
1084                                                    ArrayRef<APInt> values) {
1085   std::vector<char> data;
1086   writeAPIntsToBuffer(storageWidth, data, values);
1087   return DenseIntOrFPElementsAttr::getRaw(type, data);
1088 }
1089 
1090 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1091                                                    ArrayRef<char> data) {
1092   assert((type.isa<RankedTensorType, VectorType>()) &&
1093          "type must be ranked tensor or vector");
1094   assert(type.hasStaticShape() && "type must have static shape");
1095   bool isSplat = false;
1096   bool isValid = isValidRawBuffer(type, data, isSplat);
1097   assert(isValid);
1098   (void)isValid;
1099   return Base::get(type.getContext(), type, data, isSplat);
1100 }
1101 
1102 /// Overload of the raw 'get' method that asserts that the given type is of
1103 /// complex type. This method is used to verify type invariants that the
1104 /// templatized 'get' method cannot.
1105 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1106                                                           ArrayRef<char> data,
1107                                                           int64_t dataEltSize,
1108                                                           bool isInt,
1109                                                           bool isSigned) {
1110   assert(::isValidIntOrFloat(
1111       type.getElementType().cast<ComplexType>().getElementType(),
1112       dataEltSize / 2, isInt, isSigned));
1113 
1114   int64_t numElements = data.size() / dataEltSize;
1115   (void)numElements;
1116   assert(numElements == 1 || numElements == type.getNumElements());
1117   return getRaw(type, data);
1118 }
1119 
1120 /// Overload of the 'getRaw' method that asserts that the given type is of
1121 /// integer type. This method is used to verify type invariants that the
1122 /// templatized 'get' method cannot.
1123 DenseElementsAttr
1124 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1125                                            int64_t dataEltSize, bool isInt,
1126                                            bool isSigned) {
1127   assert(
1128       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1129 
1130   int64_t numElements = data.size() / dataEltSize;
1131   assert(numElements == 1 || numElements == type.getNumElements());
1132   (void)numElements;
1133   return getRaw(type, data);
1134 }
1135 
1136 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1137     const char *inRawData, char *outRawData, size_t elementBitWidth,
1138     size_t numElements) {
1139   using llvm::support::ulittle16_t;
1140   using llvm::support::ulittle32_t;
1141   using llvm::support::ulittle64_t;
1142 
1143   assert(llvm::support::endian::system_endianness() == // NOLINT
1144          llvm::support::endianness::big);              // NOLINT
1145   // NOLINT to avoid warning message about replacing by static_assert()
1146 
1147   // Following std::copy_n always converts endianness on BE machine.
1148   switch (elementBitWidth) {
1149   case 16: {
1150     const ulittle16_t *inRawDataPos =
1151         reinterpret_cast<const ulittle16_t *>(inRawData);
1152     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1153     std::copy_n(inRawDataPos, numElements, outDataPos);
1154     break;
1155   }
1156   case 32: {
1157     const ulittle32_t *inRawDataPos =
1158         reinterpret_cast<const ulittle32_t *>(inRawData);
1159     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1160     std::copy_n(inRawDataPos, numElements, outDataPos);
1161     break;
1162   }
1163   case 64: {
1164     const ulittle64_t *inRawDataPos =
1165         reinterpret_cast<const ulittle64_t *>(inRawData);
1166     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1167     std::copy_n(inRawDataPos, numElements, outDataPos);
1168     break;
1169   }
1170   default: {
1171     size_t nBytes = elementBitWidth / CHAR_BIT;
1172     for (size_t i = 0; i < nBytes; i++)
1173       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1174     break;
1175   }
1176   }
1177 }
1178 
1179 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1180     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1181     ShapedType type) {
1182   size_t numElements = type.getNumElements();
1183   Type elementType = type.getElementType();
1184   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1185     elementType = complexTy.getElementType();
1186     numElements = numElements * 2;
1187   }
1188   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1189   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1190          inRawData.size() <= outRawData.size());
1191   if (elementBitWidth <= CHAR_BIT)
1192     std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1193   else
1194     convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1195                                     elementBitWidth, numElements);
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // DenseFPElementsAttr
1200 //===----------------------------------------------------------------------===//
1201 
1202 template <typename Fn, typename Attr>
1203 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1204                                 Type newElementType,
1205                                 llvm::SmallVectorImpl<char> &data) {
1206   size_t bitWidth = getDenseElementBitWidth(newElementType);
1207   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1208 
1209   ShapedType newArrayType;
1210   if (inType.isa<RankedTensorType>())
1211     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1212   else if (inType.isa<UnrankedTensorType>())
1213     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1214   else if (auto vType = inType.dyn_cast<VectorType>())
1215     newArrayType = VectorType::get(vType.getShape(), newElementType,
1216                                    vType.getNumScalableDims());
1217   else
1218     assert(newArrayType && "Unhandled tensor type");
1219 
1220   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1221   data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1222 
1223   // Functor used to process a single element value of the attribute.
1224   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1225     auto newInt = mapping(value);
1226     assert(newInt.getBitWidth() == bitWidth);
1227     writeBits(data.data(), index * storageBitWidth, newInt);
1228   };
1229 
1230   // Check for the splat case.
1231   if (attr.isSplat()) {
1232     processElt(*attr.begin(), /*index=*/0);
1233     return newArrayType;
1234   }
1235 
1236   // Otherwise, process all of the element values.
1237   uint64_t elementIdx = 0;
1238   for (auto value : attr)
1239     processElt(value, elementIdx++);
1240   return newArrayType;
1241 }
1242 
1243 DenseElementsAttr DenseFPElementsAttr::mapValues(
1244     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1245   llvm::SmallVector<char, 8> elementData;
1246   auto newArrayType =
1247       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1248 
1249   return getRaw(newArrayType, elementData);
1250 }
1251 
1252 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1253 bool DenseFPElementsAttr::classof(Attribute attr) {
1254   return attr.isa<DenseElementsAttr>() &&
1255          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1256 }
1257 
1258 //===----------------------------------------------------------------------===//
1259 // DenseIntElementsAttr
1260 //===----------------------------------------------------------------------===//
1261 
1262 DenseElementsAttr DenseIntElementsAttr::mapValues(
1263     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1264   llvm::SmallVector<char, 8> elementData;
1265   auto newArrayType =
1266       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1267   return getRaw(newArrayType, elementData);
1268 }
1269 
1270 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1271 bool DenseIntElementsAttr::classof(Attribute attr) {
1272   return attr.isa<DenseElementsAttr>() &&
1273          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1274 }
1275 
1276 //===----------------------------------------------------------------------===//
1277 // OpaqueElementsAttr
1278 //===----------------------------------------------------------------------===//
1279 
1280 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1281   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1282   if (!dialect)
1283     return true;
1284   auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
1285   if (!interface)
1286     return true;
1287   return failed(interface->decode(*this, result));
1288 }
1289 
1290 LogicalResult
1291 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1292                            StringAttr dialect, StringRef value,
1293                            ShapedType type) {
1294   if (!Dialect::isValidNamespace(dialect.strref()))
1295     return emitError() << "invalid dialect namespace '" << dialect << "'";
1296   return success();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // SparseElementsAttr
1301 //===----------------------------------------------------------------------===//
1302 
1303 /// Get a zero APFloat for the given sparse attribute.
1304 APFloat SparseElementsAttr::getZeroAPFloat() const {
1305   auto eltType = getElementType().cast<FloatType>();
1306   return APFloat(eltType.getFloatSemantics());
1307 }
1308 
1309 /// Get a zero APInt for the given sparse attribute.
1310 APInt SparseElementsAttr::getZeroAPInt() const {
1311   auto eltType = getElementType().cast<IntegerType>();
1312   return APInt::getZero(eltType.getWidth());
1313 }
1314 
1315 /// Get a zero attribute for the given attribute type.
1316 Attribute SparseElementsAttr::getZeroAttr() const {
1317   auto eltType = getElementType();
1318 
1319   // Handle floating point elements.
1320   if (eltType.isa<FloatType>())
1321     return FloatAttr::get(eltType, 0);
1322 
1323   // Handle string type.
1324   if (getValues().isa<DenseStringElementsAttr>())
1325     return StringAttr::get("", eltType);
1326 
1327   // Otherwise, this is an integer.
1328   return IntegerAttr::get(eltType, 0);
1329 }
1330 
1331 /// Flatten, and return, all of the sparse indices in this attribute in
1332 /// row-major order.
1333 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1334   std::vector<ptrdiff_t> flatSparseIndices;
1335 
1336   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1337   // as a 1-D index array.
1338   auto sparseIndices = getIndices();
1339   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1340   if (sparseIndices.isSplat()) {
1341     SmallVector<uint64_t, 8> indices(getType().getRank(),
1342                                      *sparseIndexValues.begin());
1343     flatSparseIndices.push_back(getFlattenedIndex(indices));
1344     return flatSparseIndices;
1345   }
1346 
1347   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1348   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1349   size_t rank = getType().getRank();
1350   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1351     flatSparseIndices.push_back(getFlattenedIndex(
1352         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1353   return flatSparseIndices;
1354 }
1355 
1356 LogicalResult
1357 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1358                            ShapedType type, DenseIntElementsAttr sparseIndices,
1359                            DenseElementsAttr values) {
1360   ShapedType valuesType = values.getType();
1361   if (valuesType.getRank() != 1)
1362     return emitError() << "expected 1-d tensor for sparse element values";
1363 
1364   // Verify the indices and values shape.
1365   ShapedType indicesType = sparseIndices.getType();
1366   auto emitShapeError = [&]() {
1367     return emitError() << "expected shape ([" << type.getShape()
1368                        << "]); inferred shape of indices literal (["
1369                        << indicesType.getShape()
1370                        << "]); inferred shape of values literal (["
1371                        << valuesType.getShape() << "])";
1372   };
1373   // Verify indices shape.
1374   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1375   if (indicesRank == 2) {
1376     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1377       return emitShapeError();
1378   } else if (indicesRank != 1 || rank != 1) {
1379     return emitShapeError();
1380   }
1381   // Verify the values shape.
1382   int64_t numSparseIndices = indicesType.getDimSize(0);
1383   if (numSparseIndices != valuesType.getDimSize(0))
1384     return emitShapeError();
1385 
1386   // Verify that the sparse indices are within the value shape.
1387   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1388     return emitError()
1389            << "sparse index #" << indexNum
1390            << " is not contained within the value shape, with index=[" << index
1391            << "], and type=" << type;
1392   };
1393 
1394   // Handle the case where the index values are a splat.
1395   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1396   if (sparseIndices.isSplat()) {
1397     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1398     if (!ElementsAttr::isValidIndex(type, indices))
1399       return emitIndexError(0, indices);
1400     return success();
1401   }
1402 
1403   // Otherwise, reinterpret each index as an ArrayRef.
1404   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1405     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1406                              rank);
1407     if (!ElementsAttr::isValidIndex(type, index))
1408       return emitIndexError(i, index);
1409   }
1410 
1411   return success();
1412 }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // TypeAttr
1416 //===----------------------------------------------------------------------===//
1417 
1418 void TypeAttr::walkImmediateSubElements(
1419     function_ref<void(Attribute)> walkAttrsFn,
1420     function_ref<void(Type)> walkTypesFn) const {
1421   walkTypesFn(getValue());
1422 }
1423