1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Types.h"
16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 /// Tablegen Attribute Definitions
26 //===----------------------------------------------------------------------===//
27 
28 #define GET_ATTRDEF_CLASSES
29 #include "mlir/IR/BuiltinAttributes.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // DictionaryAttr
33 //===----------------------------------------------------------------------===//
34 
35 /// Helper function that does either an in place sort or sorts from source array
36 /// into destination. If inPlace then storage is both the source and the
37 /// destination, else value is the source and storage destination. Returns
38 /// whether source was sorted.
39 template <bool inPlace>
40 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
41                                SmallVectorImpl<NamedAttribute> &storage) {
42   // Specialize for the common case.
43   switch (value.size()) {
44   case 0:
45     // Zero already sorted.
46     break;
47   case 1:
48     // One already sorted but may need to be copied.
49     if (!inPlace)
50       storage.assign({value[0]});
51     break;
52   case 2: {
53     bool isSorted = value[0] < value[1];
54     if (inPlace) {
55       if (!isSorted)
56         std::swap(storage[0], storage[1]);
57     } else if (isSorted) {
58       storage.assign({value[0], value[1]});
59     } else {
60       storage.assign({value[1], value[0]});
61     }
62     return !isSorted;
63   }
64   default:
65     if (!inPlace)
66       storage.assign(value.begin(), value.end());
67     // Check to see they are sorted already.
68     bool isSorted = llvm::is_sorted(value);
69     if (!isSorted) {
70       // If not, do a general sort.
71       llvm::array_pod_sort(storage.begin(), storage.end());
72       value = storage;
73     }
74     return !isSorted;
75   }
76   return false;
77 }
78 
79 /// Returns an entry with a duplicate name from the given sorted array of named
80 /// attributes. Returns llvm::None if all elements have unique names.
81 static Optional<NamedAttribute>
82 findDuplicateElement(ArrayRef<NamedAttribute> value) {
83   const Optional<NamedAttribute> none{llvm::None};
84   if (value.size() < 2)
85     return none;
86 
87   if (value.size() == 2)
88     return value[0].first == value[1].first ? value[0] : none;
89 
90   auto it = std::adjacent_find(
91       value.begin(), value.end(),
92       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
93   return it != value.end() ? *it : none;
94 }
95 
96 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
97                           SmallVectorImpl<NamedAttribute> &storage) {
98   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
99   assert(!findDuplicateElement(storage) &&
100          "DictionaryAttr element names must be unique");
101   return isSorted;
102 }
103 
104 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
105   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
106   assert(!findDuplicateElement(array) &&
107          "DictionaryAttr element names must be unique");
108   return isSorted;
109 }
110 
111 Optional<NamedAttribute>
112 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
113                               bool isSorted) {
114   if (!isSorted)
115     dictionaryAttrSort</*inPlace=*/true>(array, array);
116   return findDuplicateElement(array);
117 }
118 
119 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
120                                    ArrayRef<NamedAttribute> value) {
121   if (value.empty())
122     return DictionaryAttr::getEmpty(context);
123   assert(llvm::all_of(value,
124                       [](const NamedAttribute &attr) { return attr.second; }) &&
125          "value cannot have null entries");
126 
127   // We need to sort the element list to canonicalize it.
128   SmallVector<NamedAttribute, 8> storage;
129   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
130     value = storage;
131   assert(!findDuplicateElement(value) &&
132          "DictionaryAttr element names must be unique");
133   return Base::get(context, value);
134 }
135 /// Construct a dictionary with an array of values that is known to already be
136 /// sorted by name and uniqued.
137 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
138                                              ArrayRef<NamedAttribute> value) {
139   if (value.empty())
140     return DictionaryAttr::getEmpty(context);
141   // Ensure that the attribute elements are unique and sorted.
142   assert(llvm::is_sorted(value,
143                          [](NamedAttribute l, NamedAttribute r) {
144                            return l.first.strref() < r.first.strref();
145                          }) &&
146          "expected attribute values to be sorted");
147   assert(!findDuplicateElement(value) &&
148          "DictionaryAttr element names must be unique");
149   return Base::get(context, value);
150 }
151 
152 /// Return the specified attribute if present, null otherwise.
153 Attribute DictionaryAttr::get(StringRef name) const {
154   Optional<NamedAttribute> attr = getNamed(name);
155   return attr ? attr->second : nullptr;
156 }
157 Attribute DictionaryAttr::get(Identifier name) const {
158   Optional<NamedAttribute> attr = getNamed(name);
159   return attr ? attr->second : nullptr;
160 }
161 
162 /// Return the specified named attribute if present, None otherwise.
163 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
164   ArrayRef<NamedAttribute> values = getValue();
165   const auto *it = llvm::lower_bound(values, name);
166   return it != values.end() && it->first == name ? *it
167                                                  : Optional<NamedAttribute>();
168 }
169 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
170   for (auto elt : getValue())
171     if (elt.first == name)
172       return elt;
173   return llvm::None;
174 }
175 
176 DictionaryAttr::iterator DictionaryAttr::begin() const {
177   return getValue().begin();
178 }
179 DictionaryAttr::iterator DictionaryAttr::end() const {
180   return getValue().end();
181 }
182 size_t DictionaryAttr::size() const { return getValue().size(); }
183 
184 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
185   return Base::get(context, ArrayRef<NamedAttribute>());
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // FloatAttr
190 //===----------------------------------------------------------------------===//
191 
192 FloatAttr FloatAttr::get(Type type, double value) {
193   return Base::get(type.getContext(), type, value);
194 }
195 
196 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
197                                 Type type, double value) {
198   return Base::getChecked(emitError, type.getContext(), type, value);
199 }
200 
201 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
202   return Base::get(type.getContext(), type, value);
203 }
204 
205 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
206                                 Type type, const APFloat &value) {
207   return Base::getChecked(emitError, type.getContext(), type, value);
208 }
209 
210 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
211 
212 double FloatAttr::getValueAsDouble() const {
213   return getValueAsDouble(getValue());
214 }
215 double FloatAttr::getValueAsDouble(APFloat value) {
216   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
217     bool losesInfo = false;
218     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
219                   &losesInfo);
220   }
221   return value.convertToDouble();
222 }
223 
224 /// Verify construction invariants.
225 static LogicalResult
226 verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
227                           Type type) {
228   if (!type.isa<FloatType>())
229     return emitError() << "expected floating point type";
230   return success();
231 }
232 
233 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
234                                 Type type, double value) {
235   return verifyFloatTypeInvariants(emitError, type);
236 }
237 
238 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
239                                 Type type, const APFloat &value) {
240   // Verify that the type is correct.
241   if (failed(verifyFloatTypeInvariants(emitError, type)))
242     return failure();
243 
244   // Verify that the type semantics match that of the value.
245   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
246     return emitError()
247            << "FloatAttr type doesn't match the type implied by its value";
248   }
249   return success();
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // SymbolRefAttr
254 //===----------------------------------------------------------------------===//
255 
256 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
257   return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
258 }
259 
260 StringRef SymbolRefAttr::getLeafReference() const {
261   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
262   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // IntegerAttr
267 //===----------------------------------------------------------------------===//
268 
269 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
270   if (type.isSignlessInteger(1))
271     return BoolAttr::get(type.getContext(), value.getBoolValue());
272   return Base::get(type.getContext(), type, value);
273 }
274 
275 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
276   // This uses 64 bit APInts by default for index type.
277   if (type.isIndex())
278     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
279 
280   auto intType = type.cast<IntegerType>();
281   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
282 }
283 
284 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
285 
286 int64_t IntegerAttr::getInt() const {
287   assert((getImpl()->getType().isIndex() ||
288           getImpl()->getType().isSignlessInteger()) &&
289          "must be signless integer");
290   return getValue().getSExtValue();
291 }
292 
293 int64_t IntegerAttr::getSInt() const {
294   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
295   return getValue().getSExtValue();
296 }
297 
298 uint64_t IntegerAttr::getUInt() const {
299   assert(getImpl()->getType().isUnsignedInteger() &&
300          "must be unsigned integer");
301   return getValue().getZExtValue();
302 }
303 
304 static LogicalResult
305 verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
306                             Type type) {
307   if (type.isa<IntegerType, IndexType>())
308     return success();
309   return emitError() << "expected integer or index type";
310 }
311 
312 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
313                                   Type type, int64_t value) {
314   return verifyIntegerTypeInvariants(emitError, type);
315 }
316 
317 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318                                   Type type, const APInt &value) {
319   if (failed(verifyIntegerTypeInvariants(emitError, type)))
320     return failure();
321   if (auto integerType = type.dyn_cast<IntegerType>())
322     if (integerType.getWidth() != value.getBitWidth())
323       return emitError() << "integer type bit width (" << integerType.getWidth()
324                          << ") doesn't match value bit width ("
325                          << value.getBitWidth() << ")";
326   return success();
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // BoolAttr
331 
332 bool BoolAttr::getValue() const {
333   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
334   return storage->getValue().getBoolValue();
335 }
336 
337 bool BoolAttr::classof(Attribute attr) {
338   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
339   return intAttr && intAttr.getType().isSignlessInteger(1);
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // OpaqueAttr
344 //===----------------------------------------------------------------------===//
345 
346 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
347                                  Identifier dialect, StringRef attrData,
348                                  Type type) {
349   if (!Dialect::isValidNamespace(dialect.strref()))
350     return emitError() << "invalid dialect namespace '" << dialect << "'";
351   return success();
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // ElementsAttr
356 //===----------------------------------------------------------------------===//
357 
358 ShapedType ElementsAttr::getType() const {
359   return Attribute::getType().cast<ShapedType>();
360 }
361 
362 /// Returns the number of elements held by this attribute.
363 int64_t ElementsAttr::getNumElements() const {
364   return getType().getNumElements();
365 }
366 
367 /// Return the value at the given index. If index does not refer to a valid
368 /// element, then a null attribute is returned.
369 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
370   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
371     return denseAttr.getValue(index);
372   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
373     return opaqueAttr.getValue(index);
374   return cast<SparseElementsAttr>().getValue(index);
375 }
376 
377 /// Return if the given 'index' refers to a valid element in this attribute.
378 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
379   auto type = getType();
380 
381   // Verify that the rank of the indices matches the held type.
382   auto rank = type.getRank();
383   if (rank == 0 && index.size() == 1 && index[0] == 0)
384     return true;
385   if (rank != static_cast<int64_t>(index.size()))
386     return false;
387 
388   // Verify that all of the indices are within the shape dimensions.
389   auto shape = type.getShape();
390   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
391     return static_cast<int64_t>(index[i]) < shape[i];
392   });
393 }
394 
395 ElementsAttr
396 ElementsAttr::mapValues(Type newElementType,
397                         function_ref<APInt(const APInt &)> mapping) const {
398   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
399     return intOrFpAttr.mapValues(newElementType, mapping);
400   llvm_unreachable("unsupported ElementsAttr subtype");
401 }
402 
403 ElementsAttr
404 ElementsAttr::mapValues(Type newElementType,
405                         function_ref<APInt(const APFloat &)> mapping) const {
406   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
407     return intOrFpAttr.mapValues(newElementType, mapping);
408   llvm_unreachable("unsupported ElementsAttr subtype");
409 }
410 
411 /// Method for support type inquiry through isa, cast and dyn_cast.
412 bool ElementsAttr::classof(Attribute attr) {
413   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
414                   OpaqueElementsAttr, SparseElementsAttr>();
415 }
416 
417 /// Returns the 1 dimensional flattened row-major index from the given
418 /// multi-dimensional index.
419 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
420   assert(isValidIndex(index) && "expected valid multi-dimensional index");
421   auto type = getType();
422 
423   // Reduce the provided multidimensional index into a flattended 1D row-major
424   // index.
425   auto rank = type.getRank();
426   auto shape = type.getShape();
427   uint64_t valueIndex = 0;
428   uint64_t dimMultiplier = 1;
429   for (int i = rank - 1; i >= 0; --i) {
430     valueIndex += index[i] * dimMultiplier;
431     dimMultiplier *= shape[i];
432   }
433   return valueIndex;
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // DenseElementsAttr Utilities
438 //===----------------------------------------------------------------------===//
439 
/// Get the number of bits a dense element of logical width `origWidth`
/// occupies within the buffer. 1-bit elements keep a width of 1; every other
/// width is rounded up to the next multiple of 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return 1;
  // Round up to a multiple of 8 (equivalent to llvm::alignTo<8>).
  return (origWidth + 7) / 8 * 8;
}
445 static size_t getDenseElementStorageWidth(Type elementType) {
446   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
447 }
448 
/// Set the bit at `bitPos` (bit `bitPos % CHAR_BIT`, counted from the least
/// significant bit, of byte `bitPos / CHAR_BIT`) in `rawData` to `value`.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte = static_cast<char>(byte | mask);
  else
    byte = static_cast<char>(byte & ~mask);
}
456 
/// Return the bit at `bitPos` (bit `bitPos % CHAR_BIT`, counted from the least
/// significant bit, of byte `bitPos / CHAR_BIT`) in `rawData`.
static bool getBit(const char *rawData, size_t bitPos) {
  const auto byte = static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
461 
/// Copy the first `numBytes` bytes of data held by `value` (APInt) into the
/// char array `result`, on a big-endian (BE) host.
///
/// All APInt words except the last are copied verbatim. Only part of the last
/// word is meaningful, so (per the inline comments) it is converted to LE in
/// the temporary buffer `valueLE`, and the needed bytes are converted back to
/// BE into `result`, both via
/// DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine.
///
/// Preconditions (asserted): the host is big-endian, and `value`'s word storage
/// holds at least `numBytes` bytes.
/// NOTE(review): this also assumes `numBytes` reaches into the last word, i.e.
/// numBytes >= (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; otherwise
/// `numBytes - lastWordPos` below would underflow. The caller in this file
/// (writeBits) passes divideCeil(bitWidth, CHAR_BIT), which presumably
/// satisfies this — confirm if adding new callers.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // NOTE: despite its name, `numFilledWords` is a BYTE count (all bytes of
  // every word except the last); it is used as a byte count and offset below.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`. Only the remaining (numBytes - lastWordPos)
  // bytes are written, so `result` is never written past `numBytes`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
491 
/// Copy `numBytes` bytes from the char array `inArray` into the word storage
/// of `result` (APInt), on a big-endian (BE) host.
///
/// `result` must already have its final bit width: its word count decides how
/// many leading bytes are copied verbatim. The tail of `inArray` that belongs
/// to the last word is converted to LE in the temporary buffer `inArrayLE` and
/// then converted back to BE into the last word of `result`, both via
/// DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine. The APInt's
/// storage is written directly through a const_cast of getRawData().
///
/// Preconditions (asserted): the host is big-endian, and `result`'s word
/// storage holds at least `numBytes` bytes.
/// NOTE(review): this also assumes `numBytes` reaches into the last word of
/// `result`; otherwise `numBytes - lastWordPos` below would underflow. The
/// caller in this file (readBits) presumably guarantees this — confirm.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // NOTE: despite its name, `numFilledWords` is a BYTE count (all bytes of
  // every word except the last); it is used as a byte count and offset below.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
527 
528 /// Writes value to the bit position `bitPos` in array `rawData`.
529 static void writeBits(char *rawData, size_t bitPos, APInt value) {
530   size_t bitWidth = value.getBitWidth();
531 
532   // If the bitwidth is 1 we just toggle the specific bit.
533   if (bitWidth == 1)
534     return setBit(rawData, bitPos, value.isOneValue());
535 
536   // Otherwise, the bit position is guaranteed to be byte aligned.
537   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
538   if (llvm::support::endian::system_endianness() ==
539       llvm::support::endianness::big) {
540     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
541     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
542     // work correctly in BE format.
543     // ex. `value` (2 words including 10 bytes)
544     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
545     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
546                                  rawData + (bitPos / CHAR_BIT));
547   } else {
548     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
549                 llvm::divideCeil(bitWidth, CHAR_BIT),
550                 rawData + (bitPos / CHAR_BIT));
551   }
552 }
553 
554 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
555 /// `rawData`.
556 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
557   // Handle a boolean bit position.
558   if (bitWidth == 1)
559     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
560 
561   // Otherwise, the bit position must be 8-bit aligned.
562   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
563   APInt result(bitWidth, 0);
564   if (llvm::support::endian::system_endianness() ==
565       llvm::support::endianness::big) {
566     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
567     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
568     // work correctly in BE format.
569     // ex. `result` (2 words including 10 bytes)
570     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
571     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
572                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
573   } else {
574     std::copy_n(rawData + (bitPos / CHAR_BIT),
575                 llvm::divideCeil(bitWidth, CHAR_BIT),
576                 const_cast<char *>(
577                     reinterpret_cast<const char *>(result.getRawData())));
578   }
579   return result;
580 }
581 
582 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
583 /// the same element count as 'type'.
584 template <typename Values>
585 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
586   return (values.size() == 1) ||
587          (type.getNumElements() == static_cast<int64_t>(values.size()));
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // DenseElementsAttr Iterators
592 //===----------------------------------------------------------------------===//
593 
594 //===----------------------------------------------------------------------===//
595 // AttributeElementIterator
596 
597 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
598     DenseElementsAttr attr, size_t index)
599     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
600                                       Attribute, Attribute, Attribute>(
601           attr.getAsOpaquePointer(), index) {}
602 
603 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
604   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
605   Type eltTy = owner.getType().getElementType();
606   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
607     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
608   if (eltTy.isa<IndexType>())
609     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
610   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
611     IntElementIterator intIt(owner, index);
612     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
613     return FloatAttr::get(eltTy, *floatIt);
614   }
615   if (owner.isa<DenseStringElementsAttr>()) {
616     ArrayRef<StringRef> vals = owner.getRawStringData();
617     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
618   }
619   llvm_unreachable("unexpected element type");
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // BoolElementIterator
624 
625 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
626     DenseElementsAttr attr, size_t dataIndex)
627     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
628           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
629 
630 bool DenseElementsAttr::BoolElementIterator::operator*() const {
631   return getBit(getData(), getDataIndex());
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // IntElementIterator
636 
637 DenseElementsAttr::IntElementIterator::IntElementIterator(
638     DenseElementsAttr attr, size_t dataIndex)
639     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
640           attr.getRawData().data(), attr.isSplat(), dataIndex),
641       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
642 
643 APInt DenseElementsAttr::IntElementIterator::operator*() const {
644   return readBits(getData(),
645                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
646                   bitWidth);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // ComplexIntElementIterator
651 
652 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
653     DenseElementsAttr attr, size_t dataIndex)
654     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
655                                       std::complex<APInt>, std::complex<APInt>,
656                                       std::complex<APInt>>(
657           attr.getRawData().data(), attr.isSplat(), dataIndex) {
658   auto complexType = attr.getType().getElementType().cast<ComplexType>();
659   bitWidth = getDenseElementBitWidth(complexType.getElementType());
660 }
661 
662 std::complex<APInt>
663 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
664   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
665   size_t offset = getDataIndex() * storageWidth * 2;
666   return {readBits(getData(), offset, bitWidth),
667           readBits(getData(), offset + storageWidth, bitWidth)};
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // FloatElementIterator
672 
673 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
674     const llvm::fltSemantics &smt, IntElementIterator it)
675     : llvm::mapped_iterator<IntElementIterator,
676                             std::function<APFloat(const APInt &)>>(
677           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
678 
679 //===----------------------------------------------------------------------===//
680 // ComplexFloatElementIterator
681 
682 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
683     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
684     : llvm::mapped_iterator<
685           ComplexIntElementIterator,
686           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
687           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
688             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
689           }) {}
690 
691 //===----------------------------------------------------------------------===//
692 // DenseElementsAttr
693 //===----------------------------------------------------------------------===//
694 
695 /// Method for support type inquiry through isa, cast and dyn_cast.
696 bool DenseElementsAttr::classof(Attribute attr) {
697   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
698 }
699 
700 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
701                                          ArrayRef<Attribute> values) {
702   assert(hasSameElementsOrSplat(type, values));
703 
704   // If the element type is not based on int/float/index, assume it is a string
705   // type.
706   auto eltType = type.getElementType();
707   if (!type.getElementType().isIntOrIndexOrFloat()) {
708     SmallVector<StringRef, 8> stringValues;
709     stringValues.reserve(values.size());
710     for (Attribute attr : values) {
711       assert(attr.isa<StringAttr>() &&
712              "expected string value for non integer/index/float element");
713       stringValues.push_back(attr.cast<StringAttr>().getValue());
714     }
715     return get(type, stringValues);
716   }
717 
718   // Otherwise, get the raw storage width to use for the allocation.
719   size_t bitWidth = getDenseElementBitWidth(eltType);
720   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
721 
722   // Compress the attribute values into a character buffer.
723   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
724                             values.size());
725   APInt intVal;
726   for (unsigned i = 0, e = values.size(); i < e; ++i) {
727     assert(eltType == values[i].getType() &&
728            "expected attribute value to have element type");
729     if (eltType.isa<FloatType>())
730       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
731     else if (eltType.isa<IntegerType>())
732       intVal = values[i].cast<IntegerAttr>().getValue();
733     else
734       llvm_unreachable("unexpected element type");
735 
736     assert(intVal.getBitWidth() == bitWidth &&
737            "expected value to have same bitwidth as element type");
738     writeBits(data.data(), i * storageBitWidth, intVal);
739   }
740   return DenseIntOrFPElementsAttr::getRaw(type, data,
741                                           /*isSplat=*/(values.size() == 1));
742 }
743 
744 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
745                                          ArrayRef<bool> values) {
746   assert(hasSameElementsOrSplat(type, values));
747   assert(type.getElementType().isInteger(1));
748 
749   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
750   for (int i = 0, e = values.size(); i != e; ++i)
751     setBit(buff.data(), i, values[i]);
752   return DenseIntOrFPElementsAttr::getRaw(type, buff,
753                                           /*isSplat=*/(values.size() == 1));
754 }
755 
756 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
757                                          ArrayRef<StringRef> values) {
758   assert(!type.getElementType().isIntOrFloat());
759   return DenseStringElementsAttr::get(type, values);
760 }
761 
762 /// Constructs a dense integer elements attribute from an array of APInt
763 /// values. Each APInt value is expected to have the same bitwidth as the
764 /// element type of 'type'.
765 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
766                                          ArrayRef<APInt> values) {
767   assert(type.getElementType().isIntOrIndex());
768   assert(hasSameElementsOrSplat(type, values));
769   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
770   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
771                                           /*isSplat=*/(values.size() == 1));
772 }
773 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
774                                          ArrayRef<std::complex<APInt>> values) {
775   ComplexType complex = type.getElementType().cast<ComplexType>();
776   assert(complex.getElementType().isa<IntegerType>());
777   assert(hasSameElementsOrSplat(type, values));
778   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
779   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
780                           values.size() * 2);
781   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
782                                           /*isSplat=*/(values.size() == 1));
783 }
784 
785 // Constructs a dense float elements attribute from an array of APFloat
786 // values. Each APFloat value is expected to have the same bitwidth as the
787 // element type of 'type'.
788 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
789                                          ArrayRef<APFloat> values) {
790   assert(type.getElementType().isa<FloatType>());
791   assert(hasSameElementsOrSplat(type, values));
792   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
793   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
794                                           /*isSplat=*/(values.size() == 1));
795 }
796 DenseElementsAttr
797 DenseElementsAttr::get(ShapedType type,
798                        ArrayRef<std::complex<APFloat>> values) {
799   ComplexType complex = type.getElementType().cast<ComplexType>();
800   assert(complex.getElementType().isa<FloatType>());
801   assert(hasSameElementsOrSplat(type, values));
802   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
803                            values.size() * 2);
804   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
805   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
806                                           /*isSplat=*/(values.size() == 1));
807 }
808 
809 /// Construct a dense elements attribute from a raw buffer representing the
810 /// data for this attribute. Users should generally not use this methods as
811 /// the expected buffer format may not be a form the user expects.
812 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
813                                                       ArrayRef<char> rawBuffer,
814                                                       bool isSplatBuffer) {
815   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
816 }
817 
818 /// Returns true if the given buffer is a valid raw buffer for the given type.
819 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
820                                          ArrayRef<char> rawBuffer,
821                                          bool &detectedSplat) {
822   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
823   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
824 
825   // Storage width of 1 is special as it is packed by the bit.
826   if (storageWidth == 1) {
827     // Check for a splat, or a buffer equal to the number of elements.
828     if ((detectedSplat = rawBuffer.size() == 1))
829       return true;
830     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
831   }
832   // All other types are 8-bit aligned.
833   if ((detectedSplat = rawBufferWidth == storageWidth))
834     return true;
835   return rawBufferWidth == (storageWidth * type.getNumElements());
836 }
837 
838 /// Check the information for a C++ data type, check if this type is valid for
839 /// the current attribute. This method is used to verify specific type
840 /// invariants that the templatized 'getValues' method cannot.
841 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
842                               bool isSigned) {
843   // Make sure that the data element size is the same as the type element width.
844   if (getDenseElementBitWidth(type) !=
845       static_cast<size_t>(dataEltSize * CHAR_BIT))
846     return false;
847 
848   // Check that the element type is either float or integer or index.
849   if (!isInt)
850     return type.isa<FloatType>();
851   if (type.isIndex())
852     return true;
853 
854   auto intType = type.dyn_cast<IntegerType>();
855   if (!intType)
856     return false;
857 
858   // Make sure signedness semantics is consistent.
859   if (intType.isSignless())
860     return true;
861   return intType.isSigned() ? isSigned : !isSigned;
862 }
863 
864 /// Defaults down the subclass implementation.
865 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
866                                                    ArrayRef<char> data,
867                                                    int64_t dataEltSize,
868                                                    bool isInt, bool isSigned) {
869   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
870                                                  isSigned);
871 }
872 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
873                                                       ArrayRef<char> data,
874                                                       int64_t dataEltSize,
875                                                       bool isInt,
876                                                       bool isSigned) {
877   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
878                                                     isInt, isSigned);
879 }
880 
881 /// A method used to verify specific type invariants that the templatized 'get'
882 /// method cannot.
883 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
884                                           bool isSigned) const {
885   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
886                              isSigned);
887 }
888 
889 /// Check the information for a C++ data type, check if this type is valid for
890 /// the current attribute.
891 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
892                                        bool isSigned) const {
893   return ::isValidIntOrFloat(
894       getType().getElementType().cast<ComplexType>().getElementType(),
895       dataEltSize / 2, isInt, isSigned);
896 }
897 
898 /// Returns true if this attribute corresponds to a splat, i.e. if all element
899 /// values are the same.
900 bool DenseElementsAttr::isSplat() const {
901   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
902 }
903 
904 /// Return the held element values as a range of Attributes.
905 auto DenseElementsAttr::getAttributeValues() const
906     -> llvm::iterator_range<AttributeElementIterator> {
907   return {attr_value_begin(), attr_value_end()};
908 }
909 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
910   return AttributeElementIterator(*this, 0);
911 }
912 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
913   return AttributeElementIterator(*this, getNumElements());
914 }
915 
916 /// Return the held element values as a range of bool. The element type of
917 /// this attribute must be of integer type of bitwidth 1.
918 auto DenseElementsAttr::getBoolValues() const
919     -> llvm::iterator_range<BoolElementIterator> {
920   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
921   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
922   (void)eltType;
923   return {BoolElementIterator(*this, 0),
924           BoolElementIterator(*this, getNumElements())};
925 }
926 
927 /// Return the held element values as a range of APInts. The element type of
928 /// this attribute must be of integer type.
929 auto DenseElementsAttr::getIntValues() const
930     -> llvm::iterator_range<IntElementIterator> {
931   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
932   return {raw_int_begin(), raw_int_end()};
933 }
934 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
935   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
936   return raw_int_begin();
937 }
938 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
939   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
940   return raw_int_end();
941 }
942 auto DenseElementsAttr::getComplexIntValues() const
943     -> llvm::iterator_range<ComplexIntElementIterator> {
944   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
945   (void)eltTy;
946   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
947   return {ComplexIntElementIterator(*this, 0),
948           ComplexIntElementIterator(*this, getNumElements())};
949 }
950 
951 /// Return the held element values as a range of APFloat. The element type of
952 /// this attribute must be of float type.
953 auto DenseElementsAttr::getFloatValues() const
954     -> llvm::iterator_range<FloatElementIterator> {
955   auto elementType = getType().getElementType().cast<FloatType>();
956   const auto &elementSemantics = elementType.getFloatSemantics();
957   return {FloatElementIterator(elementSemantics, raw_int_begin()),
958           FloatElementIterator(elementSemantics, raw_int_end())};
959 }
960 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
961   return getFloatValues().begin();
962 }
963 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
964   return getFloatValues().end();
965 }
966 auto DenseElementsAttr::getComplexFloatValues() const
967     -> llvm::iterator_range<ComplexFloatElementIterator> {
968   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
969   assert(eltTy.isa<FloatType>() && "expected complex float type");
970   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
971   return {{semantics, {*this, 0}},
972           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
973 }
974 
975 /// Return the raw storage data held by this attribute.
976 ArrayRef<char> DenseElementsAttr::getRawData() const {
977   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
978 }
979 
980 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
981   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
982 }
983 
984 /// Return a new DenseElementsAttr that has the same data as the current
985 /// attribute, but has been reshaped to 'newType'. The new type must have the
986 /// same total number of elements as well as element type.
987 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
988   ShapedType curType = getType();
989   if (curType == newType)
990     return *this;
991 
992   (void)curType;
993   assert(newType.getElementType() == curType.getElementType() &&
994          "expected the same element type");
995   assert(newType.getNumElements() == curType.getNumElements() &&
996          "expected the same number of elements");
997   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
998 }
999 
1000 DenseElementsAttr
1001 DenseElementsAttr::mapValues(Type newElementType,
1002                              function_ref<APInt(const APInt &)> mapping) const {
1003   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1004 }
1005 
1006 DenseElementsAttr DenseElementsAttr::mapValues(
1007     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1008   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1009 }
1010 
1011 //===----------------------------------------------------------------------===//
1012 // DenseStringElementsAttr
1013 //===----------------------------------------------------------------------===//
1014 
1015 DenseStringElementsAttr
1016 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1017   return Base::get(type.getContext(), type, values, (values.size() == 1));
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // DenseIntOrFPElementsAttr
1022 //===----------------------------------------------------------------------===//
1023 
1024 /// Utility method to write a range of APInt values to a buffer.
1025 template <typename APRangeT>
1026 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1027                                 APRangeT &&values) {
1028   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1029   size_t offset = 0;
1030   for (auto it = values.begin(), e = values.end(); it != e;
1031        ++it, offset += storageWidth) {
1032     assert((*it).getBitWidth() <= storageWidth);
1033     writeBits(data.data(), offset, *it);
1034   }
1035 }
1036 
1037 /// Constructs a dense elements attribute from an array of raw APFloat values.
1038 /// Each APFloat value is expected to have the same bitwidth as the element
1039 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1040 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1041                                                    size_t storageWidth,
1042                                                    ArrayRef<APFloat> values,
1043                                                    bool isSplat) {
1044   std::vector<char> data;
1045   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1046   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1047   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1048 }
1049 
1050 /// Constructs a dense elements attribute from an array of raw APInt values.
1051 /// Each APInt value is expected to have the same bitwidth as the element type
1052 /// of 'type'.
1053 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1054                                                    size_t storageWidth,
1055                                                    ArrayRef<APInt> values,
1056                                                    bool isSplat) {
1057   std::vector<char> data;
1058   writeAPIntsToBuffer(storageWidth, data, values);
1059   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1060 }
1061 
1062 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1063                                                    ArrayRef<char> data,
1064                                                    bool isSplat) {
1065   assert((type.isa<RankedTensorType, VectorType>()) &&
1066          "type must be ranked tensor or vector");
1067   assert(type.hasStaticShape() && "type must have static shape");
1068   return Base::get(type.getContext(), type, data, isSplat);
1069 }
1070 
1071 /// Overload of the raw 'get' method that asserts that the given type is of
1072 /// complex type. This method is used to verify type invariants that the
1073 /// templatized 'get' method cannot.
1074 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1075                                                           ArrayRef<char> data,
1076                                                           int64_t dataEltSize,
1077                                                           bool isInt,
1078                                                           bool isSigned) {
1079   assert(::isValidIntOrFloat(
1080       type.getElementType().cast<ComplexType>().getElementType(),
1081       dataEltSize / 2, isInt, isSigned));
1082 
1083   int64_t numElements = data.size() / dataEltSize;
1084   assert(numElements == 1 || numElements == type.getNumElements());
1085   return getRaw(type, data, /*isSplat=*/numElements == 1);
1086 }
1087 
1088 /// Overload of the 'getRaw' method that asserts that the given type is of
1089 /// integer type. This method is used to verify type invariants that the
1090 /// templatized 'get' method cannot.
1091 DenseElementsAttr
1092 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1093                                            int64_t dataEltSize, bool isInt,
1094                                            bool isSigned) {
1095   assert(
1096       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1097 
1098   int64_t numElements = data.size() / dataEltSize;
1099   assert(numElements == 1 || numElements == type.getNumElements());
1100   return getRaw(type, data, /*isSplat=*/numElements == 1);
1101 }
1102 
1103 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1104     const char *inRawData, char *outRawData, size_t elementBitWidth,
1105     size_t numElements) {
1106   using llvm::support::ulittle16_t;
1107   using llvm::support::ulittle32_t;
1108   using llvm::support::ulittle64_t;
1109 
1110   assert(llvm::support::endian::system_endianness() == // NOLINT
1111          llvm::support::endianness::big);              // NOLINT
1112   // NOLINT to avoid warning message about replacing by static_assert()
1113 
1114   // Following std::copy_n always converts endianness on BE machine.
1115   switch (elementBitWidth) {
1116   case 16: {
1117     const ulittle16_t *inRawDataPos =
1118         reinterpret_cast<const ulittle16_t *>(inRawData);
1119     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1120     std::copy_n(inRawDataPos, numElements, outDataPos);
1121     break;
1122   }
1123   case 32: {
1124     const ulittle32_t *inRawDataPos =
1125         reinterpret_cast<const ulittle32_t *>(inRawData);
1126     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1127     std::copy_n(inRawDataPos, numElements, outDataPos);
1128     break;
1129   }
1130   case 64: {
1131     const ulittle64_t *inRawDataPos =
1132         reinterpret_cast<const ulittle64_t *>(inRawData);
1133     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1134     std::copy_n(inRawDataPos, numElements, outDataPos);
1135     break;
1136   }
1137   default: {
1138     size_t nBytes = elementBitWidth / CHAR_BIT;
1139     for (size_t i = 0; i < nBytes; i++)
1140       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1141     break;
1142   }
1143   }
1144 }
1145 
1146 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1147     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1148     ShapedType type) {
1149   size_t numElements = type.getNumElements();
1150   Type elementType = type.getElementType();
1151   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1152     elementType = complexTy.getElementType();
1153     numElements = numElements * 2;
1154   }
1155   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1156   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1157          inRawData.size() <= outRawData.size());
1158   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1159                                   elementBitWidth, numElements);
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // DenseFPElementsAttr
1164 //===----------------------------------------------------------------------===//
1165 
1166 template <typename Fn, typename Attr>
1167 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1168                                 Type newElementType,
1169                                 llvm::SmallVectorImpl<char> &data) {
1170   size_t bitWidth = getDenseElementBitWidth(newElementType);
1171   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1172 
1173   ShapedType newArrayType;
1174   if (inType.isa<RankedTensorType>())
1175     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1176   else if (inType.isa<UnrankedTensorType>())
1177     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1178   else if (inType.isa<VectorType>())
1179     newArrayType = VectorType::get(inType.getShape(), newElementType);
1180   else
1181     assert(newArrayType && "Unhandled tensor type");
1182 
1183   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1184   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1185 
1186   // Functor used to process a single element value of the attribute.
1187   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1188     auto newInt = mapping(value);
1189     assert(newInt.getBitWidth() == bitWidth);
1190     writeBits(data.data(), index * storageBitWidth, newInt);
1191   };
1192 
1193   // Check for the splat case.
1194   if (attr.isSplat()) {
1195     processElt(*attr.begin(), /*index=*/0);
1196     return newArrayType;
1197   }
1198 
1199   // Otherwise, process all of the element values.
1200   uint64_t elementIdx = 0;
1201   for (auto value : attr)
1202     processElt(value, elementIdx++);
1203   return newArrayType;
1204 }
1205 
1206 DenseElementsAttr DenseFPElementsAttr::mapValues(
1207     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1208   llvm::SmallVector<char, 8> elementData;
1209   auto newArrayType =
1210       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1211 
1212   return getRaw(newArrayType, elementData, isSplat());
1213 }
1214 
1215 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1216 bool DenseFPElementsAttr::classof(Attribute attr) {
1217   return attr.isa<DenseElementsAttr>() &&
1218          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // DenseIntElementsAttr
1223 //===----------------------------------------------------------------------===//
1224 
1225 DenseElementsAttr DenseIntElementsAttr::mapValues(
1226     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1227   llvm::SmallVector<char, 8> elementData;
1228   auto newArrayType =
1229       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1230 
1231   return getRaw(newArrayType, elementData, isSplat());
1232 }
1233 
1234 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1235 bool DenseIntElementsAttr::classof(Attribute attr) {
1236   return attr.isa<DenseElementsAttr>() &&
1237          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1238 }
1239 
1240 //===----------------------------------------------------------------------===//
1241 // OpaqueElementsAttr
1242 //===----------------------------------------------------------------------===//
1243 
1244 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1245                                            StringRef bytes) {
1246   assert(TensorType::isValidElementType(type.getElementType()) &&
1247          "Input element type should be a valid tensor element type");
1248   return Base::get(type.getContext(), type, dialect, bytes);
1249 }
1250 
1251 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1252 
1253 /// Return the value at the given index. If index does not refer to a valid
1254 /// element, then a null attribute is returned.
1255 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1256   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1257   return Attribute();
1258 }
1259 
1260 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1261 
1262 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1263   auto *d = getDialect();
1264   if (!d)
1265     return true;
1266   auto *interface =
1267       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1268   if (!interface)
1269     return true;
1270   return failed(interface->decode(*this, result));
1271 }
1272 
1273 //===----------------------------------------------------------------------===//
1274 // SparseElementsAttr
1275 //===----------------------------------------------------------------------===//
1276 
1277 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1278                                            DenseElementsAttr indices,
1279                                            DenseElementsAttr values) {
1280   assert(indices.getType().getElementType().isInteger(64) &&
1281          "expected sparse indices to be 64-bit integer values");
1282   assert((type.isa<RankedTensorType, VectorType>()) &&
1283          "type must be ranked tensor or vector");
1284   assert(type.hasStaticShape() && "type must have static shape");
1285   return Base::get(type.getContext(), type,
1286                    indices.cast<DenseIntElementsAttr>(), values);
1287 }
1288 
1289 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1290   return getImpl()->indices;
1291 }
1292 
1293 DenseElementsAttr SparseElementsAttr::getValues() const {
1294   return getImpl()->values;
1295 }
1296 
/// Return the value of the element at the given multi-dimensional index.
///
/// Stored index tuples live in getIndices() as 64-bit integers (asserted in
/// SparseElementsAttr::get), with `rank` consecutive values per tuple. The
/// value for tuple `i` is element `i` of getValues(). If `index` matches no
/// stored tuple, a zero attribute of the element type is returned (see
/// getZeroAttr).
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // With splat indices every stored tuple is (splatIndex, ..., splatIndex).
    // If any coordinate of `index` differs from splatIndex, nothing is stored
    // there and the value is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // NOTE(review): this map is rebuilt on every call, so each lookup costs time
  // proportional to the number of stored indices.
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  // Each key views the `rank` consecutive coordinates of tuple `i`. Taking the
  // address of the dereferenced iterator presumably relies on
  // getValues<uint64_t>() yielding references into contiguous raw storage.
  // TODO: confirm.
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
1337 
1338 /// Get a zero APFloat for the given sparse attribute.
1339 APFloat SparseElementsAttr::getZeroAPFloat() const {
1340   auto eltType = getType().getElementType().cast<FloatType>();
1341   return APFloat(eltType.getFloatSemantics());
1342 }
1343 
1344 /// Get a zero APInt for the given sparse attribute.
1345 APInt SparseElementsAttr::getZeroAPInt() const {
1346   auto eltType = getType().getElementType().cast<IntegerType>();
1347   return APInt::getNullValue(eltType.getWidth());
1348 }
1349 
1350 /// Get a zero attribute for the given attribute type.
1351 Attribute SparseElementsAttr::getZeroAttr() const {
1352   auto eltType = getType().getElementType();
1353 
1354   // Handle floating point elements.
1355   if (eltType.isa<FloatType>())
1356     return FloatAttr::get(eltType, 0);
1357 
1358   // Otherwise, this is an integer.
1359   // TODO: Handle StringAttr here.
1360   return IntegerAttr::get(eltType, 0);
1361 }
1362 
/// Flatten, and return, all of the sparse indices in this attribute in
/// row-major order.
///
/// One flattened offset (computed with getFlattenedIndex) is produced per
/// stored index tuple, in the order the tuples are stored. When the indices
/// are a splat, a single offset is returned for the tuple
/// (splatValue, ..., splatValue).
std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
  std::vector<ptrdiff_t> flatSparseIndices;

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    // Every coordinate of the (single, repeated) tuple is the splat value.
    SmallVector<uint64_t, 8> indices(getType().getRank(),
                                     *sparseIndexValues.begin());
    flatSparseIndices.push_back(getFlattenedIndex(indices));
    return flatSparseIndices;
  }

  // Otherwise, reinterpret each index as an ArrayRef when flattening.
  // Taking the address of the dereferenced iterator presumably relies on
  // getValues<uint64_t>() yielding references into contiguous raw storage.
  // TODO: confirm.
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = getType().getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    flatSparseIndices.push_back(getFlattenedIndex(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
  return flatSparseIndices;
}
1387