1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
19 #include "llvm/ADT/APSInt.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/Support/Endian.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Attribute Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/IR/BuiltinAttributes.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // BuiltinDialect
35 //===----------------------------------------------------------------------===//
36 
/// Register the builtin attribute classes listed below with the builtin
/// dialect. Note that BoolAttr, FlatSymbolRefAttr and DenseElementsAttr do not
/// appear in the list; presumably they are views over registered storage
/// classes (compare BoolAttr::classof and DenseElementsAttr::classof in this
/// file). NOTE(review): confirm before adding new entries here.
void BuiltinDialect::registerAttributes() {
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
}
44 
45 //===----------------------------------------------------------------------===//
46 // ArrayAttr
47 //===----------------------------------------------------------------------===//
48 
49 void ArrayAttr::walkImmediateSubElements(
50     function_ref<void(Attribute)> walkAttrsFn,
51     function_ref<void(Type)> walkTypesFn) const {
52   for (Attribute attr : getValue())
53     walkAttrsFn(attr);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // DictionaryAttr
58 //===----------------------------------------------------------------------===//
59 
/// Helper function that does either an in place sort or sorts from source array
/// into destination. If inPlace then storage is both the source and the
/// destination, else value is the source and storage destination.
///
/// Ordering uses `operator<` on NamedAttribute (presumably a comparison of the
/// attribute names; verify against its definition). When !inPlace, `storage`
/// always ends up holding a copy of `value` in sorted order, even when no
/// reordering was necessary.
///
/// Returns true if `value` was NOT already sorted (i.e. the output order
/// differs from the input order), false if it was already sorted. Callers
/// such as DictionaryAttr::get rely on a true result to switch over to
/// `storage`.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    if (!inPlace)
      storage.clear();
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    // Strict `<`: two entries that compare equal are reported as unsorted
    // and swapped (harmless; duplicates are rejected by the callers'
    // asserts anyway).
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      // `value` and `storage` alias here, so `isSorted` had to be computed
      // before the swap.
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already.
    bool isSorted = llvm::is_sorted(value);
    // If not, do a general sort.
    if (!isSorted)
      llvm::array_pod_sort(storage.begin(), storage.end());
    return !isSorted;
  }
  return false;
}
103 
104 /// Returns an entry with a duplicate name from the given sorted array of named
105 /// attributes. Returns llvm::None if all elements have unique names.
106 static Optional<NamedAttribute>
107 findDuplicateElement(ArrayRef<NamedAttribute> value) {
108   const Optional<NamedAttribute> none{llvm::None};
109   if (value.size() < 2)
110     return none;
111 
112   if (value.size() == 2)
113     return value[0].first == value[1].first ? value[0] : none;
114 
115   auto it = std::adjacent_find(
116       value.begin(), value.end(),
117       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
118   return it != value.end() ? *it : none;
119 }
120 
121 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
122                           SmallVectorImpl<NamedAttribute> &storage) {
123   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
124   assert(!findDuplicateElement(storage) &&
125          "DictionaryAttr element names must be unique");
126   return isSorted;
127 }
128 
129 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
130   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
131   assert(!findDuplicateElement(array) &&
132          "DictionaryAttr element names must be unique");
133   return isSorted;
134 }
135 
136 Optional<NamedAttribute>
137 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
138                               bool isSorted) {
139   if (!isSorted)
140     dictionaryAttrSort</*inPlace=*/true>(array, array);
141   return findDuplicateElement(array);
142 }
143 
144 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
145                                    ArrayRef<NamedAttribute> value) {
146   if (value.empty())
147     return DictionaryAttr::getEmpty(context);
148   assert(llvm::all_of(value,
149                       [](const NamedAttribute &attr) { return attr.second; }) &&
150          "value cannot have null entries");
151 
152   // We need to sort the element list to canonicalize it.
153   SmallVector<NamedAttribute, 8> storage;
154   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
155     value = storage;
156   assert(!findDuplicateElement(value) &&
157          "DictionaryAttr element names must be unique");
158   return Base::get(context, value);
159 }
160 /// Construct a dictionary with an array of values that is known to already be
161 /// sorted by name and uniqued.
162 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
163                                              ArrayRef<NamedAttribute> value) {
164   if (value.empty())
165     return DictionaryAttr::getEmpty(context);
166   // Ensure that the attribute elements are unique and sorted.
167   assert(llvm::is_sorted(value,
168                          [](NamedAttribute l, NamedAttribute r) {
169                            return l.first.strref() < r.first.strref();
170                          }) &&
171          "expected attribute values to be sorted");
172   assert(!findDuplicateElement(value) &&
173          "DictionaryAttr element names must be unique");
174   return Base::get(context, value);
175 }
176 
177 /// Return the specified attribute if present, null otherwise.
178 Attribute DictionaryAttr::get(StringRef name) const {
179   Optional<NamedAttribute> attr = getNamed(name);
180   return attr ? attr->second : nullptr;
181 }
182 Attribute DictionaryAttr::get(Identifier name) const {
183   Optional<NamedAttribute> attr = getNamed(name);
184   return attr ? attr->second : nullptr;
185 }
186 
187 /// Return the specified named attribute if present, None otherwise.
188 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
189   ArrayRef<NamedAttribute> values = getValue();
190   const auto *it = llvm::lower_bound(values, name);
191   return it != values.end() && it->first == name ? *it
192                                                  : Optional<NamedAttribute>();
193 }
194 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
195   for (auto elt : getValue())
196     if (elt.first == name)
197       return elt;
198   return llvm::None;
199 }
200 
201 DictionaryAttr::iterator DictionaryAttr::begin() const {
202   return getValue().begin();
203 }
204 DictionaryAttr::iterator DictionaryAttr::end() const {
205   return getValue().end();
206 }
207 size_t DictionaryAttr::size() const { return getValue().size(); }
208 
209 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
210   return Base::get(context, ArrayRef<NamedAttribute>());
211 }
212 
213 void DictionaryAttr::walkImmediateSubElements(
214     function_ref<void(Attribute)> walkAttrsFn,
215     function_ref<void(Type)> walkTypesFn) const {
216   for (Attribute attr : llvm::make_second_range(getValue()))
217     walkAttrsFn(attr);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // StringAttr
222 //===----------------------------------------------------------------------===//
223 
224 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
225   return Base::get(context, "", NoneType::get(context));
226 }
227 
228 /// Twine support for StringAttr.
229 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
230   // Fast-path empty twine.
231   if (twine.isTriviallyEmpty())
232     return get(context);
233   SmallVector<char, 32> tempStr;
234   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
235 }
236 
237 /// Twine support for StringAttr.
238 StringAttr StringAttr::get(const Twine &twine, Type type) {
239   SmallVector<char, 32> tempStr;
240   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // FloatAttr
245 //===----------------------------------------------------------------------===//
246 
247 double FloatAttr::getValueAsDouble() const {
248   return getValueAsDouble(getValue());
249 }
250 double FloatAttr::getValueAsDouble(APFloat value) {
251   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
252     bool losesInfo = false;
253     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
254                   &losesInfo);
255   }
256   return value.convertToDouble();
257 }
258 
259 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
260                                 Type type, APFloat value) {
261   // Verify that the type is correct.
262   if (!type.isa<FloatType>())
263     return emitError() << "expected floating point type";
264 
265   // Verify that the type semantics match that of the value.
266   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
267     return emitError()
268            << "FloatAttr type doesn't match the type implied by its value";
269   }
270   return success();
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // SymbolRefAttr
275 //===----------------------------------------------------------------------===//
276 
277 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
278                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
279   return get(StringAttr::get(ctx, value), nestedRefs);
280 }
281 
282 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
283   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
284 }
285 
286 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
287   return get(value, {}).cast<FlatSymbolRefAttr>();
288 }
289 
290 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
291   auto symName =
292       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
293   assert(symName && "value does not have a valid symbol name");
294   return SymbolRefAttr::get(symName);
295 }
296 
297 StringAttr SymbolRefAttr::getLeafReference() const {
298   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
299   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // IntegerAttr
304 //===----------------------------------------------------------------------===//
305 
306 int64_t IntegerAttr::getInt() const {
307   assert((getType().isIndex() || getType().isSignlessInteger()) &&
308          "must be signless integer");
309   return getValue().getSExtValue();
310 }
311 
312 int64_t IntegerAttr::getSInt() const {
313   assert(getType().isSignedInteger() && "must be signed integer");
314   return getValue().getSExtValue();
315 }
316 
317 uint64_t IntegerAttr::getUInt() const {
318   assert(getType().isUnsignedInteger() && "must be unsigned integer");
319   return getValue().getZExtValue();
320 }
321 
322 /// Return the value as an APSInt which carries the signed from the type of
323 /// the attribute.  This traps on signless integers types!
324 APSInt IntegerAttr::getAPSInt() const {
325   assert(!getType().isSignlessInteger() &&
326          "Signless integers don't carry a sign for APSInt");
327   return APSInt(getValue(), getType().isUnsignedInteger());
328 }
329 
330 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
331                                   Type type, APInt value) {
332   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
333     if (integerType.getWidth() != value.getBitWidth())
334       return emitError() << "integer type bit width (" << integerType.getWidth()
335                          << ") doesn't match value bit width ("
336                          << value.getBitWidth() << ")";
337     return success();
338   }
339   if (type.isa<IndexType>())
340     return success();
341   return emitError() << "expected integer or index type";
342 }
343 
344 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
345   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
346   return attr.cast<BoolAttr>();
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // BoolAttr
351 
352 bool BoolAttr::getValue() const {
353   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
354   return storage->value.getBoolValue();
355 }
356 
357 bool BoolAttr::classof(Attribute attr) {
358   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
359   return intAttr && intAttr.getType().isSignlessInteger(1);
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // OpaqueAttr
364 //===----------------------------------------------------------------------===//
365 
366 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
367                                  Identifier dialect, StringRef attrData,
368                                  Type type) {
369   if (!Dialect::isValidNamespace(dialect.strref()))
370     return emitError() << "invalid dialect namespace '" << dialect << "'";
371 
372   // Check that the dialect is actually registered.
373   MLIRContext *context = dialect.getContext();
374   if (!context->allowsUnregisteredDialects() &&
375       !context->getLoadedDialect(dialect.strref())) {
376     return emitError()
377            << "#" << dialect << "<\"" << attrData << "\"> : " << type
378            << " attribute created with unregistered dialect. If this is "
379               "intended, please call allowUnregisteredDialects() on the "
380               "MLIRContext, or use -allow-unregistered-dialect with "
381               "the MLIR opt tool used";
382   }
383 
384   return success();
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // DenseElementsAttr Utilities
389 //===----------------------------------------------------------------------===//
390 
391 /// Get the bitwidth of a dense element type within the buffer.
392 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
393 static size_t getDenseElementStorageWidth(size_t origWidth) {
394   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
395 }
396 static size_t getDenseElementStorageWidth(Type elementType) {
397   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
398 }
399 
/// Assign `value` to bit number `bitPos` of the packed bit buffer `rawData`.
/// Bits are numbered least-significant-first within each byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}

/// Return whether bit number `bitPos` of the packed bit buffer `rawData` is
/// set. Bits are numbered least-significant-first within each byte.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
412 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Intended for big-endian hosts only (asserted below; asserts are stripped in
/// release builds). `value` must have at least `numBytes` bytes of word
/// storage, and `result` must have room for `numBytes` bytes. Every APInt word
/// except the last is copied verbatim. The (possibly partially filled) last
/// word is converted to LE order in a scratch buffer and then back to BE, so
/// that only its `numBytes - lastWordPos` significant bytes are written to
/// `result`. This assumes `numBytes` is greater than the byte size of all but
/// the last word, which holds when `numBytes == ceil(bitWidth / 8)`.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Despite its name, `numFilledWords` is a byte count: all words but the
  // last, multiplied by the bytes per word.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  // Only the significant tail bytes are converted, so nothing is written past
  // `result + numBytes`.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
442 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// Intended for big-endian hosts only (asserted below; asserts are stripped in
/// release builds). `result` must already have enough words to hold
/// `numBytes` bytes. The leading full words are copied verbatim. The trailing
/// `numBytes - lastWordPos` bytes are converted to LE order in a scratch
/// buffer and then written back as a whole BE word into the last word of
/// `result`. The function writes into the APInt's word storage through
/// `getRaw Data()` via const_cast. NOTE(review): this relies on that pointer
/// addressing the APInt's own mutable storage; confirm against the APInt
/// implementation in use.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Despite its name, `numFilledWords` is a byte count: all words but the
  // last, multiplied by the bytes per word.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  // A full word is converted here; the unused high bytes come from the
  // zero-initialized scratch buffer.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
478 
479 /// Writes value to the bit position `bitPos` in array `rawData`.
480 static void writeBits(char *rawData, size_t bitPos, APInt value) {
481   size_t bitWidth = value.getBitWidth();
482 
483   // If the bitwidth is 1 we just toggle the specific bit.
484   if (bitWidth == 1)
485     return setBit(rawData, bitPos, value.isOneValue());
486 
487   // Otherwise, the bit position is guaranteed to be byte aligned.
488   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
489   if (llvm::support::endian::system_endianness() ==
490       llvm::support::endianness::big) {
491     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
492     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
493     // work correctly in BE format.
494     // ex. `value` (2 words including 10 bytes)
495     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
496     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
497                                  rawData + (bitPos / CHAR_BIT));
498   } else {
499     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
500                 llvm::divideCeil(bitWidth, CHAR_BIT),
501                 rawData + (bitPos / CHAR_BIT));
502   }
503 }
504 
505 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
506 /// `rawData`.
507 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
508   // Handle a boolean bit position.
509   if (bitWidth == 1)
510     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
511 
512   // Otherwise, the bit position must be 8-bit aligned.
513   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
514   APInt result(bitWidth, 0);
515   if (llvm::support::endian::system_endianness() ==
516       llvm::support::endianness::big) {
517     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
518     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
519     // work correctly in BE format.
520     // ex. `result` (2 words including 10 bytes)
521     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
522     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
523                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
524   } else {
525     std::copy_n(rawData + (bitPos / CHAR_BIT),
526                 llvm::divideCeil(bitWidth, CHAR_BIT),
527                 const_cast<char *>(
528                     reinterpret_cast<const char *>(result.getRawData())));
529   }
530   return result;
531 }
532 
533 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
534 /// the same element count as 'type'.
535 template <typename Values>
536 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
537   return (values.size() == 1) ||
538          (type.getNumElements() == static_cast<int64_t>(values.size()));
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // DenseElementsAttr Iterators
543 //===----------------------------------------------------------------------===//
544 
545 //===----------------------------------------------------------------------===//
546 // AttributeElementIterator
547 
548 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
549     DenseElementsAttr attr, size_t index)
550     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
551                                       Attribute, Attribute, Attribute>(
552           attr.getAsOpaquePointer(), index) {}
553 
554 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
555   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
556   Type eltTy = owner.getElementType();
557   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
558     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
559   if (eltTy.isa<IndexType>())
560     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
561   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
562     IntElementIterator intIt(owner, index);
563     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
564     return FloatAttr::get(eltTy, *floatIt);
565   }
566   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
567     auto complexEltTy = complexTy.getElementType();
568     ComplexIntElementIterator complexIntIt(owner, index);
569     if (complexEltTy.isa<IntegerType>()) {
570       auto value = *complexIntIt;
571       auto real = IntegerAttr::get(complexEltTy, value.real());
572       auto imag = IntegerAttr::get(complexEltTy, value.imag());
573       return ArrayAttr::get(complexTy.getContext(),
574                             ArrayRef<Attribute>{real, imag});
575     }
576 
577     ComplexFloatElementIterator complexFloatIt(
578         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
579     auto value = *complexFloatIt;
580     auto real = FloatAttr::get(complexEltTy, value.real());
581     auto imag = FloatAttr::get(complexEltTy, value.imag());
582     return ArrayAttr::get(complexTy.getContext(),
583                           ArrayRef<Attribute>{real, imag});
584   }
585   if (owner.isa<DenseStringElementsAttr>()) {
586     ArrayRef<StringRef> vals = owner.getRawStringData();
587     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
588   }
589   llvm_unreachable("unexpected element type");
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // BoolElementIterator
594 
595 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
596     DenseElementsAttr attr, size_t dataIndex)
597     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
598           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
599 
600 bool DenseElementsAttr::BoolElementIterator::operator*() const {
601   return getBit(getData(), getDataIndex());
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // IntElementIterator
606 
607 DenseElementsAttr::IntElementIterator::IntElementIterator(
608     DenseElementsAttr attr, size_t dataIndex)
609     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
610           attr.getRawData().data(), attr.isSplat(), dataIndex),
611       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
612 
613 APInt DenseElementsAttr::IntElementIterator::operator*() const {
614   return readBits(getData(),
615                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
616                   bitWidth);
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // ComplexIntElementIterator
621 
622 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
623     DenseElementsAttr attr, size_t dataIndex)
624     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
625                                       std::complex<APInt>, std::complex<APInt>,
626                                       std::complex<APInt>>(
627           attr.getRawData().data(), attr.isSplat(), dataIndex) {
628   auto complexType = attr.getElementType().cast<ComplexType>();
629   bitWidth = getDenseElementBitWidth(complexType.getElementType());
630 }
631 
632 std::complex<APInt>
633 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
634   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
635   size_t offset = getDataIndex() * storageWidth * 2;
636   return {readBits(getData(), offset, bitWidth),
637           readBits(getData(), offset + storageWidth, bitWidth)};
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // FloatElementIterator
642 
643 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
644     const llvm::fltSemantics &smt, IntElementIterator it)
645     : llvm::mapped_iterator<IntElementIterator,
646                             std::function<APFloat(const APInt &)>>(
647           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
648 
649 //===----------------------------------------------------------------------===//
650 // ComplexFloatElementIterator
651 
652 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
653     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
654     : llvm::mapped_iterator<
655           ComplexIntElementIterator,
656           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
657           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
658             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
659           }) {}
660 
661 //===----------------------------------------------------------------------===//
662 // DenseElementsAttr
663 //===----------------------------------------------------------------------===//
664 
665 /// Method for support type inquiry through isa, cast and dyn_cast.
666 bool DenseElementsAttr::classof(Attribute attr) {
667   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
668 }
669 
670 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
671                                          ArrayRef<Attribute> values) {
672   assert(hasSameElementsOrSplat(type, values));
673 
674   // If the element type is not based on int/float/index, assume it is a string
675   // type.
676   auto eltType = type.getElementType();
677   if (!type.getElementType().isIntOrIndexOrFloat()) {
678     SmallVector<StringRef, 8> stringValues;
679     stringValues.reserve(values.size());
680     for (Attribute attr : values) {
681       assert(attr.isa<StringAttr>() &&
682              "expected string value for non integer/index/float element");
683       stringValues.push_back(attr.cast<StringAttr>().getValue());
684     }
685     return get(type, stringValues);
686   }
687 
688   // Otherwise, get the raw storage width to use for the allocation.
689   size_t bitWidth = getDenseElementBitWidth(eltType);
690   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
691 
692   // Compress the attribute values into a character buffer.
693   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
694                             values.size());
695   APInt intVal;
696   for (unsigned i = 0, e = values.size(); i < e; ++i) {
697     assert(eltType == values[i].getType() &&
698            "expected attribute value to have element type");
699     if (eltType.isa<FloatType>())
700       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
701     else if (eltType.isa<IntegerType, IndexType>())
702       intVal = values[i].cast<IntegerAttr>().getValue();
703     else
704       llvm_unreachable("unexpected element type");
705 
706     assert(intVal.getBitWidth() == bitWidth &&
707            "expected value to have same bitwidth as element type");
708     writeBits(data.data(), i * storageBitWidth, intVal);
709   }
710   return DenseIntOrFPElementsAttr::getRaw(type, data,
711                                           /*isSplat=*/(values.size() == 1));
712 }
713 
714 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
715                                          ArrayRef<bool> values) {
716   assert(hasSameElementsOrSplat(type, values));
717   assert(type.getElementType().isInteger(1));
718 
719   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
720   for (int i = 0, e = values.size(); i != e; ++i)
721     setBit(buff.data(), i, values[i]);
722   return DenseIntOrFPElementsAttr::getRaw(type, buff,
723                                           /*isSplat=*/(values.size() == 1));
724 }
725 
726 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
727                                          ArrayRef<StringRef> values) {
728   assert(!type.getElementType().isIntOrFloat());
729   return DenseStringElementsAttr::get(type, values);
730 }
731 
732 /// Constructs a dense integer elements attribute from an array of APInt
733 /// values. Each APInt value is expected to have the same bitwidth as the
734 /// element type of 'type'.
735 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
736                                          ArrayRef<APInt> values) {
737   assert(type.getElementType().isIntOrIndex());
738   assert(hasSameElementsOrSplat(type, values));
739   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
740   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
741                                           /*isSplat=*/(values.size() == 1));
742 }
743 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
744                                          ArrayRef<std::complex<APInt>> values) {
745   ComplexType complex = type.getElementType().cast<ComplexType>();
746   assert(complex.getElementType().isa<IntegerType>());
747   assert(hasSameElementsOrSplat(type, values));
748   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
749   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
750                           values.size() * 2);
751   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
752                                           /*isSplat=*/(values.size() == 1));
753 }
754 
755 // Constructs a dense float elements attribute from an array of APFloat
756 // values. Each APFloat value is expected to have the same bitwidth as the
757 // element type of 'type'.
758 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
759                                          ArrayRef<APFloat> values) {
760   assert(type.getElementType().isa<FloatType>());
761   assert(hasSameElementsOrSplat(type, values));
762   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
763   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
764                                           /*isSplat=*/(values.size() == 1));
765 }
766 DenseElementsAttr
767 DenseElementsAttr::get(ShapedType type,
768                        ArrayRef<std::complex<APFloat>> values) {
769   ComplexType complex = type.getElementType().cast<ComplexType>();
770   assert(complex.getElementType().isa<FloatType>());
771   assert(hasSameElementsOrSplat(type, values));
772   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
773                            values.size() * 2);
774   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
775   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
776                                           /*isSplat=*/(values.size() == 1));
777 }
778 
779 /// Construct a dense elements attribute from a raw buffer representing the
780 /// data for this attribute. Users should generally not use this methods as
781 /// the expected buffer format may not be a form the user expects.
782 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
783                                                       ArrayRef<char> rawBuffer,
784                                                       bool isSplatBuffer) {
785   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
786 }
787 
788 /// Returns true if the given buffer is a valid raw buffer for the given type.
789 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
790                                          ArrayRef<char> rawBuffer,
791                                          bool &detectedSplat) {
792   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
793   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
794 
795   // Storage width of 1 is special as it is packed by the bit.
796   if (storageWidth == 1) {
797     // Check for a splat, or a buffer equal to the number of elements which
798     // consists of either all 0's or all 1's.
799     detectedSplat = false;
800     if (rawBuffer.size() == 1) {
801       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
802       if (rawByte == 0 || rawByte == 0xff) {
803         detectedSplat = true;
804         return true;
805       }
806     }
807     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
808   }
809   // All other types are 8-bit aligned.
810   if ((detectedSplat = rawBufferWidth == storageWidth))
811     return true;
812   return rawBufferWidth == (storageWidth * type.getNumElements());
813 }
814 
815 /// Check the information for a C++ data type, check if this type is valid for
816 /// the current attribute. This method is used to verify specific type
817 /// invariants that the templatized 'getValues' method cannot.
818 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
819                               bool isSigned) {
820   // Make sure that the data element size is the same as the type element width.
821   if (getDenseElementBitWidth(type) !=
822       static_cast<size_t>(dataEltSize * CHAR_BIT))
823     return false;
824 
825   // Check that the element type is either float or integer or index.
826   if (!isInt)
827     return type.isa<FloatType>();
828   if (type.isIndex())
829     return true;
830 
831   auto intType = type.dyn_cast<IntegerType>();
832   if (!intType)
833     return false;
834 
835   // Make sure signedness semantics is consistent.
836   if (intType.isSignless())
837     return true;
838   return intType.isSigned() ? isSigned : !isSigned;
839 }
840 
841 /// Defaults down the subclass implementation.
842 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
843                                                    ArrayRef<char> data,
844                                                    int64_t dataEltSize,
845                                                    bool isInt, bool isSigned) {
846   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
847                                                  isSigned);
848 }
849 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
850                                                       ArrayRef<char> data,
851                                                       int64_t dataEltSize,
852                                                       bool isInt,
853                                                       bool isSigned) {
854   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
855                                                     isInt, isSigned);
856 }
857 
858 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
859                                           bool isSigned) const {
860   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
861 }
862 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
863                                        bool isSigned) const {
864   return ::isValidIntOrFloat(
865       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
866       isInt, isSigned);
867 }
868 
869 /// Returns true if this attribute corresponds to a splat, i.e. if all element
870 /// values are the same.
871 bool DenseElementsAttr::isSplat() const {
872   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
873 }
874 
875 /// Return if the given complex type has an integer element type.
876 static bool isComplexOfIntType(Type type) {
877   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
878 }
879 
880 auto DenseElementsAttr::getComplexIntValues() const
881     -> llvm::iterator_range<ComplexIntElementIterator> {
882   assert(isComplexOfIntType(getElementType()) &&
883          "expected complex integral type");
884   return {ComplexIntElementIterator(*this, 0),
885           ComplexIntElementIterator(*this, getNumElements())};
886 }
887 auto DenseElementsAttr::complex_value_begin() const
888     -> ComplexIntElementIterator {
889   assert(isComplexOfIntType(getElementType()) &&
890          "expected complex integral type");
891   return ComplexIntElementIterator(*this, 0);
892 }
893 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
894   assert(isComplexOfIntType(getElementType()) &&
895          "expected complex integral type");
896   return ComplexIntElementIterator(*this, getNumElements());
897 }
898 
899 /// Return the held element values as a range of APFloat. The element type of
900 /// this attribute must be of float type.
901 auto DenseElementsAttr::getFloatValues() const
902     -> llvm::iterator_range<FloatElementIterator> {
903   auto elementType = getElementType().cast<FloatType>();
904   const auto &elementSemantics = elementType.getFloatSemantics();
905   return {FloatElementIterator(elementSemantics, raw_int_begin()),
906           FloatElementIterator(elementSemantics, raw_int_end())};
907 }
908 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
909   auto elementType = getElementType().cast<FloatType>();
910   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
911 }
912 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
913   auto elementType = getElementType().cast<FloatType>();
914   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
915 }
916 
917 auto DenseElementsAttr::getComplexFloatValues() const
918     -> llvm::iterator_range<ComplexFloatElementIterator> {
919   Type eltTy = getElementType().cast<ComplexType>().getElementType();
920   assert(eltTy.isa<FloatType>() && "expected complex float type");
921   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
922   return {{semantics, {*this, 0}},
923           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
924 }
925 auto DenseElementsAttr::complex_float_value_begin() const
926     -> ComplexFloatElementIterator {
927   Type eltTy = getElementType().cast<ComplexType>().getElementType();
928   assert(eltTy.isa<FloatType>() && "expected complex float type");
929   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
930 }
931 auto DenseElementsAttr::complex_float_value_end() const
932     -> ComplexFloatElementIterator {
933   Type eltTy = getElementType().cast<ComplexType>().getElementType();
934   assert(eltTy.isa<FloatType>() && "expected complex float type");
935   return {eltTy.cast<FloatType>().getFloatSemantics(),
936           {*this, static_cast<size_t>(getNumElements())}};
937 }
938 
939 /// Return the raw storage data held by this attribute.
940 ArrayRef<char> DenseElementsAttr::getRawData() const {
941   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
942 }
943 
944 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
945   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
946 }
947 
948 /// Return a new DenseElementsAttr that has the same data as the current
949 /// attribute, but has been reshaped to 'newType'. The new type must have the
950 /// same total number of elements as well as element type.
951 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
952   ShapedType curType = getType();
953   if (curType == newType)
954     return *this;
955 
956   assert(newType.getElementType() == curType.getElementType() &&
957          "expected the same element type");
958   assert(newType.getNumElements() == curType.getNumElements() &&
959          "expected the same number of elements");
960   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
961 }
962 
963 /// Return a new DenseElementsAttr that has the same data as the current
964 /// attribute, but has bitcast elements such that it is now 'newType'. The new
965 /// type must have the same shape and element types of the same bitwidth as the
966 /// current type.
967 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
968   ShapedType curType = getType();
969   Type curElType = curType.getElementType();
970   if (curElType == newElType)
971     return *this;
972 
973   assert(getDenseElementBitWidth(newElType) ==
974              getDenseElementBitWidth(curElType) &&
975          "expected element types with the same bitwidth");
976   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
977                                           getRawData(), isSplat());
978 }
979 
980 DenseElementsAttr
981 DenseElementsAttr::mapValues(Type newElementType,
982                              function_ref<APInt(const APInt &)> mapping) const {
983   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
984 }
985 
986 DenseElementsAttr DenseElementsAttr::mapValues(
987     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
988   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
989 }
990 
991 ShapedType DenseElementsAttr::getType() const {
992   return Attribute::getType().cast<ShapedType>();
993 }
994 
995 Type DenseElementsAttr::getElementType() const {
996   return getType().getElementType();
997 }
998 
999 int64_t DenseElementsAttr::getNumElements() const {
1000   return getType().getNumElements();
1001 }
1002 
1003 //===----------------------------------------------------------------------===//
1004 // DenseIntOrFPElementsAttr
1005 //===----------------------------------------------------------------------===//
1006 
1007 /// Utility method to write a range of APInt values to a buffer.
1008 template <typename APRangeT>
1009 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1010                                 APRangeT &&values) {
1011   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1012   size_t offset = 0;
1013   for (auto it = values.begin(), e = values.end(); it != e;
1014        ++it, offset += storageWidth) {
1015     assert((*it).getBitWidth() <= storageWidth);
1016     writeBits(data.data(), offset, *it);
1017   }
1018 }
1019 
1020 /// Constructs a dense elements attribute from an array of raw APFloat values.
1021 /// Each APFloat value is expected to have the same bitwidth as the element
1022 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1023 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1024                                                    size_t storageWidth,
1025                                                    ArrayRef<APFloat> values,
1026                                                    bool isSplat) {
1027   std::vector<char> data;
1028   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1029   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1030   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1031 }
1032 
1033 /// Constructs a dense elements attribute from an array of raw APInt values.
1034 /// Each APInt value is expected to have the same bitwidth as the element type
1035 /// of 'type'.
1036 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1037                                                    size_t storageWidth,
1038                                                    ArrayRef<APInt> values,
1039                                                    bool isSplat) {
1040   std::vector<char> data;
1041   writeAPIntsToBuffer(storageWidth, data, values);
1042   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1043 }
1044 
1045 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1046                                                    ArrayRef<char> data,
1047                                                    bool isSplat) {
1048   assert((type.isa<RankedTensorType, VectorType>()) &&
1049          "type must be ranked tensor or vector");
1050   assert(type.hasStaticShape() && "type must have static shape");
1051   return Base::get(type.getContext(), type, data, isSplat);
1052 }
1053 
1054 /// Overload of the raw 'get' method that asserts that the given type is of
1055 /// complex type. This method is used to verify type invariants that the
1056 /// templatized 'get' method cannot.
1057 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1058                                                           ArrayRef<char> data,
1059                                                           int64_t dataEltSize,
1060                                                           bool isInt,
1061                                                           bool isSigned) {
1062   assert(::isValidIntOrFloat(
1063       type.getElementType().cast<ComplexType>().getElementType(),
1064       dataEltSize / 2, isInt, isSigned));
1065 
1066   int64_t numElements = data.size() / dataEltSize;
1067   assert(numElements == 1 || numElements == type.getNumElements());
1068   return getRaw(type, data, /*isSplat=*/numElements == 1);
1069 }
1070 
1071 /// Overload of the 'getRaw' method that asserts that the given type is of
1072 /// integer type. This method is used to verify type invariants that the
1073 /// templatized 'get' method cannot.
1074 DenseElementsAttr
1075 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1076                                            int64_t dataEltSize, bool isInt,
1077                                            bool isSigned) {
1078   assert(
1079       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1080 
1081   int64_t numElements = data.size() / dataEltSize;
1082   assert(numElements == 1 || numElements == type.getNumElements());
1083   return getRaw(type, data, /*isSplat=*/numElements == 1);
1084 }
1085 
1086 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1087     const char *inRawData, char *outRawData, size_t elementBitWidth,
1088     size_t numElements) {
1089   using llvm::support::ulittle16_t;
1090   using llvm::support::ulittle32_t;
1091   using llvm::support::ulittle64_t;
1092 
1093   assert(llvm::support::endian::system_endianness() == // NOLINT
1094          llvm::support::endianness::big);              // NOLINT
1095   // NOLINT to avoid warning message about replacing by static_assert()
1096 
1097   // Following std::copy_n always converts endianness on BE machine.
1098   switch (elementBitWidth) {
1099   case 16: {
1100     const ulittle16_t *inRawDataPos =
1101         reinterpret_cast<const ulittle16_t *>(inRawData);
1102     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1103     std::copy_n(inRawDataPos, numElements, outDataPos);
1104     break;
1105   }
1106   case 32: {
1107     const ulittle32_t *inRawDataPos =
1108         reinterpret_cast<const ulittle32_t *>(inRawData);
1109     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1110     std::copy_n(inRawDataPos, numElements, outDataPos);
1111     break;
1112   }
1113   case 64: {
1114     const ulittle64_t *inRawDataPos =
1115         reinterpret_cast<const ulittle64_t *>(inRawData);
1116     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1117     std::copy_n(inRawDataPos, numElements, outDataPos);
1118     break;
1119   }
1120   default: {
1121     size_t nBytes = elementBitWidth / CHAR_BIT;
1122     for (size_t i = 0; i < nBytes; i++)
1123       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1124     break;
1125   }
1126   }
1127 }
1128 
1129 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1130     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1131     ShapedType type) {
1132   size_t numElements = type.getNumElements();
1133   Type elementType = type.getElementType();
1134   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1135     elementType = complexTy.getElementType();
1136     numElements = numElements * 2;
1137   }
1138   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1139   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1140          inRawData.size() <= outRawData.size());
1141   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1142                                   elementBitWidth, numElements);
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // DenseFPElementsAttr
1147 //===----------------------------------------------------------------------===//
1148 
1149 template <typename Fn, typename Attr>
1150 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1151                                 Type newElementType,
1152                                 llvm::SmallVectorImpl<char> &data) {
1153   size_t bitWidth = getDenseElementBitWidth(newElementType);
1154   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1155 
1156   ShapedType newArrayType;
1157   if (inType.isa<RankedTensorType>())
1158     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1159   else if (inType.isa<UnrankedTensorType>())
1160     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1161   else if (inType.isa<VectorType>())
1162     newArrayType = VectorType::get(inType.getShape(), newElementType);
1163   else
1164     assert(newArrayType && "Unhandled tensor type");
1165 
1166   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1167   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1168 
1169   // Functor used to process a single element value of the attribute.
1170   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1171     auto newInt = mapping(value);
1172     assert(newInt.getBitWidth() == bitWidth);
1173     writeBits(data.data(), index * storageBitWidth, newInt);
1174   };
1175 
1176   // Check for the splat case.
1177   if (attr.isSplat()) {
1178     processElt(*attr.begin(), /*index=*/0);
1179     return newArrayType;
1180   }
1181 
1182   // Otherwise, process all of the element values.
1183   uint64_t elementIdx = 0;
1184   for (auto value : attr)
1185     processElt(value, elementIdx++);
1186   return newArrayType;
1187 }
1188 
1189 DenseElementsAttr DenseFPElementsAttr::mapValues(
1190     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1191   llvm::SmallVector<char, 8> elementData;
1192   auto newArrayType =
1193       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1194 
1195   return getRaw(newArrayType, elementData, isSplat());
1196 }
1197 
1198 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1199 bool DenseFPElementsAttr::classof(Attribute attr) {
1200   return attr.isa<DenseElementsAttr>() &&
1201          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1202 }
1203 
1204 //===----------------------------------------------------------------------===//
1205 // DenseIntElementsAttr
1206 //===----------------------------------------------------------------------===//
1207 
1208 DenseElementsAttr DenseIntElementsAttr::mapValues(
1209     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1210   llvm::SmallVector<char, 8> elementData;
1211   auto newArrayType =
1212       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1213 
1214   return getRaw(newArrayType, elementData, isSplat());
1215 }
1216 
1217 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1218 bool DenseIntElementsAttr::classof(Attribute attr) {
1219   return attr.isa<DenseElementsAttr>() &&
1220          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // OpaqueElementsAttr
1225 //===----------------------------------------------------------------------===//
1226 
1227 /// Return the value at the given index. If index does not refer to a valid
1228 /// element, then a null attribute is returned.
1229 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1230   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1231   return Attribute();
1232 }
1233 
1234 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1235   Dialect *dialect = getDialect().getDialect();
1236   if (!dialect)
1237     return true;
1238   auto *interface =
1239       dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1240   if (!interface)
1241     return true;
1242   return failed(interface->decode(*this, result));
1243 }
1244 
1245 LogicalResult
1246 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1247                            Identifier dialect, StringRef value,
1248                            ShapedType type) {
1249   if (!Dialect::isValidNamespace(dialect.strref()))
1250     return emitError() << "invalid dialect namespace '" << dialect << "'";
1251   return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // SparseElementsAttr
1256 //===----------------------------------------------------------------------===//
1257 
1258 /// Return the value of the element at the given index.
1259 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1260   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1261   auto type = getType();
1262 
1263   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1264   // as a 1-D index array.
1265   auto sparseIndices = getIndices();
1266   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1267 
1268   // Check to see if the indices are a splat.
1269   if (sparseIndices.isSplat()) {
1270     // If the index is also not a splat of the index value, we know that the
1271     // value is zero.
1272     auto splatIndex = *sparseIndexValues.begin();
1273     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1274       return getZeroAttr();
1275 
1276     // If the indices are a splat, we also expect the values to be a splat.
1277     assert(getValues().isSplat() && "expected splat values");
1278     return getValues().getSplatValue();
1279   }
1280 
1281   // Build a mapping between known indices and the offset of the stored element.
1282   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1283   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1284   size_t rank = type.getRank();
1285   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1286     mappedIndices.try_emplace(
1287         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1288 
1289   // Look for the provided index key within the mapped indices. If the provided
1290   // index is not found, then return a zero attribute.
1291   auto it = mappedIndices.find(index);
1292   if (it == mappedIndices.end())
1293     return getZeroAttr();
1294 
1295   // Otherwise, return the held sparse value element.
1296   return getValues().getValue(it->second);
1297 }
1298 
1299 /// Get a zero APFloat for the given sparse attribute.
1300 APFloat SparseElementsAttr::getZeroAPFloat() const {
1301   auto eltType = getElementType().cast<FloatType>();
1302   return APFloat(eltType.getFloatSemantics());
1303 }
1304 
1305 /// Get a zero APInt for the given sparse attribute.
1306 APInt SparseElementsAttr::getZeroAPInt() const {
1307   auto eltType = getElementType().cast<IntegerType>();
1308   return APInt::getZero(eltType.getWidth());
1309 }
1310 
1311 /// Get a zero attribute for the given attribute type.
1312 Attribute SparseElementsAttr::getZeroAttr() const {
1313   auto eltType = getElementType();
1314 
1315   // Handle floating point elements.
1316   if (eltType.isa<FloatType>())
1317     return FloatAttr::get(eltType, 0);
1318 
1319   // Otherwise, this is an integer.
1320   // TODO: Handle StringAttr here.
1321   return IntegerAttr::get(eltType, 0);
1322 }
1323 
1324 /// Flatten, and return, all of the sparse indices in this attribute in
1325 /// row-major order.
1326 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1327   std::vector<ptrdiff_t> flatSparseIndices;
1328 
1329   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1330   // as a 1-D index array.
1331   auto sparseIndices = getIndices();
1332   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1333   if (sparseIndices.isSplat()) {
1334     SmallVector<uint64_t, 8> indices(getType().getRank(),
1335                                      *sparseIndexValues.begin());
1336     flatSparseIndices.push_back(getFlattenedIndex(indices));
1337     return flatSparseIndices;
1338   }
1339 
1340   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1341   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1342   size_t rank = getType().getRank();
1343   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1344     flatSparseIndices.push_back(getFlattenedIndex(
1345         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1346   return flatSparseIndices;
1347 }
1348 
1349 LogicalResult
1350 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1351                            ShapedType type, DenseIntElementsAttr sparseIndices,
1352                            DenseElementsAttr values) {
1353   ShapedType valuesType = values.getType();
1354   if (valuesType.getRank() != 1)
1355     return emitError() << "expected 1-d tensor for sparse element values";
1356 
1357   // Verify the indices and values shape.
1358   ShapedType indicesType = sparseIndices.getType();
1359   auto emitShapeError = [&]() {
1360     return emitError() << "expected shape ([" << type.getShape()
1361                        << "]); inferred shape of indices literal (["
1362                        << indicesType.getShape()
1363                        << "]); inferred shape of values literal (["
1364                        << valuesType.getShape() << "])";
1365   };
1366   // Verify indices shape.
1367   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1368   if (indicesRank == 2) {
1369     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1370       return emitShapeError();
1371   } else if (indicesRank != 1 || rank != 1) {
1372     return emitShapeError();
1373   }
1374   // Verify the values shape.
1375   int64_t numSparseIndices = indicesType.getDimSize(0);
1376   if (numSparseIndices != valuesType.getDimSize(0))
1377     return emitShapeError();
1378 
1379   // Verify that the sparse indices are within the value shape.
1380   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1381     return emitError()
1382            << "sparse index #" << indexNum
1383            << " is not contained within the value shape, with index=[" << index
1384            << "], and type=" << type;
1385   };
1386 
1387   // Handle the case where the index values are a splat.
1388   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1389   if (sparseIndices.isSplat()) {
1390     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1391     if (!ElementsAttr::isValidIndex(type, indices))
1392       return emitIndexError(0, indices);
1393     return success();
1394   }
1395 
1396   // Otherwise, reinterpret each index as an ArrayRef.
1397   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1398     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1399                              rank);
1400     if (!ElementsAttr::isValidIndex(type, index))
1401       return emitIndexError(i, index);
1402   }
1403 
1404   return success();
1405 }
1406 
1407 //===----------------------------------------------------------------------===//
1408 // TypeAttr
1409 //===----------------------------------------------------------------------===//
1410 
1411 void TypeAttr::walkImmediateSubElements(
1412     function_ref<void(Attribute)> walkAttrsFn,
1413     function_ref<void(Type)> walkTypesFn) const {
1414   walkTypesFn(getValue());
1415 }
1416