1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Types.h"
16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AffineMapAttr
26 //===----------------------------------------------------------------------===//
27 
28 AffineMapAttr AffineMapAttr::get(AffineMap value) {
29   return Base::get(value.getContext(), value);
30 }
31 
32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
33 
34 //===----------------------------------------------------------------------===//
35 // ArrayAttr
36 //===----------------------------------------------------------------------===//
37 
38 ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
39   return Base::get(context, value);
40 }
41 
42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
43 
44 Attribute ArrayAttr::operator[](unsigned idx) const {
45   assert(idx < size() && "index out of bounds");
46   return getValue()[idx];
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // DictionaryAttr
51 //===----------------------------------------------------------------------===//
52 
53 /// Helper function that does either an in place sort or sorts from source array
54 /// into destination. If inPlace then storage is both the source and the
55 /// destination, else value is the source and storage destination. Returns
56 /// whether source was sorted.
57 template <bool inPlace>
58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
59                                SmallVectorImpl<NamedAttribute> &storage) {
60   // Specialize for the common case.
61   switch (value.size()) {
62   case 0:
63     // Zero already sorted.
64     break;
65   case 1:
66     // One already sorted but may need to be copied.
67     if (!inPlace)
68       storage.assign({value[0]});
69     break;
70   case 2: {
71     bool isSorted = value[0] < value[1];
72     if (inPlace) {
73       if (!isSorted)
74         std::swap(storage[0], storage[1]);
75     } else if (isSorted) {
76       storage.assign({value[0], value[1]});
77     } else {
78       storage.assign({value[1], value[0]});
79     }
80     return !isSorted;
81   }
82   default:
83     if (!inPlace)
84       storage.assign(value.begin(), value.end());
85     // Check to see they are sorted already.
86     bool isSorted = llvm::is_sorted(value);
87     if (!isSorted) {
88       // If not, do a general sort.
89       llvm::array_pod_sort(storage.begin(), storage.end());
90       value = storage;
91     }
92     return !isSorted;
93   }
94   return false;
95 }
96 
97 /// Returns an entry with a duplicate name from the given sorted array of named
98 /// attributes. Returns llvm::None if all elements have unique names.
99 static Optional<NamedAttribute>
100 findDuplicateElement(ArrayRef<NamedAttribute> value) {
101   const Optional<NamedAttribute> none{llvm::None};
102   if (value.size() < 2)
103     return none;
104 
105   if (value.size() == 2)
106     return value[0].first == value[1].first ? value[0] : none;
107 
108   auto it = std::adjacent_find(
109       value.begin(), value.end(),
110       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
111   return it != value.end() ? *it : none;
112 }
113 
114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
115                           SmallVectorImpl<NamedAttribute> &storage) {
116   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
117   assert(!findDuplicateElement(storage) &&
118          "DictionaryAttr element names must be unique");
119   return isSorted;
120 }
121 
122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
123   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
124   assert(!findDuplicateElement(array) &&
125          "DictionaryAttr element names must be unique");
126   return isSorted;
127 }
128 
129 Optional<NamedAttribute>
130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
131                               bool isSorted) {
132   if (!isSorted)
133     dictionaryAttrSort</*inPlace=*/true>(array, array);
134   return findDuplicateElement(array);
135 }
136 
137 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
138                                    ArrayRef<NamedAttribute> value) {
139   if (value.empty())
140     return DictionaryAttr::getEmpty(context);
141   assert(llvm::all_of(value,
142                       [](const NamedAttribute &attr) { return attr.second; }) &&
143          "value cannot have null entries");
144 
145   // We need to sort the element list to canonicalize it.
146   SmallVector<NamedAttribute, 8> storage;
147   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148     value = storage;
149   assert(!findDuplicateElement(value) &&
150          "DictionaryAttr element names must be unique");
151   return Base::get(context, value);
152 }
153 /// Construct a dictionary with an array of values that is known to already be
154 /// sorted by name and uniqued.
155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
156                                              MLIRContext *context) {
157   if (value.empty())
158     return DictionaryAttr::getEmpty(context);
159   // Ensure that the attribute elements are unique and sorted.
160   assert(llvm::is_sorted(value,
161                          [](NamedAttribute l, NamedAttribute r) {
162                            return l.first.strref() < r.first.strref();
163                          }) &&
164          "expected attribute values to be sorted");
165   assert(!findDuplicateElement(value) &&
166          "DictionaryAttr element names must be unique");
167   return Base::get(context, value);
168 }
169 
170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
171   return getImpl()->getElements();
172 }
173 
174 /// Return the specified attribute if present, null otherwise.
175 Attribute DictionaryAttr::get(StringRef name) const {
176   Optional<NamedAttribute> attr = getNamed(name);
177   return attr ? attr->second : nullptr;
178 }
179 Attribute DictionaryAttr::get(Identifier name) const {
180   Optional<NamedAttribute> attr = getNamed(name);
181   return attr ? attr->second : nullptr;
182 }
183 
184 /// Return the specified named attribute if present, None otherwise.
185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
186   ArrayRef<NamedAttribute> values = getValue();
187   const auto *it = llvm::lower_bound(values, name);
188   return it != values.end() && it->first == name ? *it
189                                                  : Optional<NamedAttribute>();
190 }
191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
192   for (auto elt : getValue())
193     if (elt.first == name)
194       return elt;
195   return llvm::None;
196 }
197 
198 DictionaryAttr::iterator DictionaryAttr::begin() const {
199   return getValue().begin();
200 }
201 DictionaryAttr::iterator DictionaryAttr::end() const {
202   return getValue().end();
203 }
204 size_t DictionaryAttr::size() const { return getValue().size(); }
205 
206 //===----------------------------------------------------------------------===//
207 // FloatAttr
208 //===----------------------------------------------------------------------===//
209 
210 FloatAttr FloatAttr::get(Type type, double value) {
211   return Base::get(type.getContext(), type, value);
212 }
213 
214 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
215   return Base::getChecked(loc, type, value);
216 }
217 
218 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
219   return Base::get(type.getContext(), type, value);
220 }
221 
222 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
223   return Base::getChecked(loc, type, value);
224 }
225 
226 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
227 
228 double FloatAttr::getValueAsDouble() const {
229   return getValueAsDouble(getValue());
230 }
231 double FloatAttr::getValueAsDouble(APFloat value) {
232   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
233     bool losesInfo = false;
234     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
235                   &losesInfo);
236   }
237   return value.convertToDouble();
238 }
239 
240 /// Verify construction invariants.
241 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
242   if (!type.isa<FloatType>())
243     return emitError(loc, "expected floating point type");
244   return success();
245 }
246 
247 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
248                                                       double value) {
249   return verifyFloatTypeInvariants(loc, type);
250 }
251 
252 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
253                                                       const APFloat &value) {
254   // Verify that the type is correct.
255   if (failed(verifyFloatTypeInvariants(loc, type)))
256     return failure();
257 
258   // Verify that the type semantics match that of the value.
259   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
260     return emitError(
261         loc, "FloatAttr type doesn't match the type implied by its value");
262   }
263   return success();
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // SymbolRefAttr
268 //===----------------------------------------------------------------------===//
269 
270 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
271   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
272 }
273 
274 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
275                                  ArrayRef<FlatSymbolRefAttr> nestedReferences) {
276   return Base::get(ctx, value, nestedReferences);
277 }
278 
279 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
280 
281 StringRef SymbolRefAttr::getLeafReference() const {
282   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
283   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
284 }
285 
286 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
287   return getImpl()->getNestedRefs();
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // IntegerAttr
292 //===----------------------------------------------------------------------===//
293 
294 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
295   if (type.isSignlessInteger(1))
296     return BoolAttr::get(type.getContext(), value.getBoolValue());
297   return Base::get(type.getContext(), type, value);
298 }
299 
300 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
301   // This uses 64 bit APInts by default for index type.
302   if (type.isIndex())
303     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
304 
305   auto intType = type.cast<IntegerType>();
306   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
307 }
308 
309 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
310 
311 int64_t IntegerAttr::getInt() const {
312   assert((getImpl()->getType().isIndex() ||
313           getImpl()->getType().isSignlessInteger()) &&
314          "must be signless integer");
315   return getValue().getSExtValue();
316 }
317 
318 int64_t IntegerAttr::getSInt() const {
319   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
320   return getValue().getSExtValue();
321 }
322 
323 uint64_t IntegerAttr::getUInt() const {
324   assert(getImpl()->getType().isUnsignedInteger() &&
325          "must be unsigned integer");
326   return getValue().getZExtValue();
327 }
328 
329 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
330   if (type.isa<IntegerType, IndexType>())
331     return success();
332   return emitError(loc, "expected integer or index type");
333 }
334 
335 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
336                                                         int64_t value) {
337   return verifyIntegerTypeInvariants(loc, type);
338 }
339 
340 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
341                                                         const APInt &value) {
342   if (failed(verifyIntegerTypeInvariants(loc, type)))
343     return failure();
344   if (auto integerType = type.dyn_cast<IntegerType>())
345     if (integerType.getWidth() != value.getBitWidth())
346       return emitError(loc, "integer type bit width (")
347              << integerType.getWidth() << ") doesn't match value bit width ("
348              << value.getBitWidth() << ")";
349   return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // BoolAttr
354 
355 bool BoolAttr::getValue() const {
356   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
357   return storage->getValue().getBoolValue();
358 }
359 
360 bool BoolAttr::classof(Attribute attr) {
361   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
362   return intAttr && intAttr.getType().isSignlessInteger(1);
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // IntegerSetAttr
367 //===----------------------------------------------------------------------===//
368 
369 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
370   return Base::get(value.getConstraint(0).getContext(), value);
371 }
372 
373 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
374 
375 //===----------------------------------------------------------------------===//
376 // OpaqueAttr
377 //===----------------------------------------------------------------------===//
378 
379 OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
380                            StringRef attrData, Type type) {
381   return Base::get(context, dialect, attrData, type);
382 }
383 
384 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
385                                   Type type, Location location) {
386   return Base::getChecked(location, dialect, attrData, type);
387 }
388 
389 /// Returns the dialect namespace of the opaque attribute.
390 Identifier OpaqueAttr::getDialectNamespace() const {
391   return getImpl()->dialectNamespace;
392 }
393 
394 /// Returns the raw attribute data of the opaque attribute.
395 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
396 
397 /// Verify the construction of an opaque attribute.
398 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
399                                                        Identifier dialect,
400                                                        StringRef attrData,
401                                                        Type type) {
402   if (!Dialect::isValidNamespace(dialect.strref()))
403     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
404   return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // StringAttr
409 //===----------------------------------------------------------------------===//
410 
411 StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
412   return get(bytes, NoneType::get(context));
413 }
414 
415 /// Get an instance of a StringAttr with the given string and Type.
416 StringAttr StringAttr::get(StringRef bytes, Type type) {
417   return Base::get(type.getContext(), bytes, type);
418 }
419 
420 StringRef StringAttr::getValue() const { return getImpl()->value; }
421 
422 //===----------------------------------------------------------------------===//
423 // TypeAttr
424 //===----------------------------------------------------------------------===//
425 
426 TypeAttr TypeAttr::get(Type value) {
427   return Base::get(value.getContext(), value);
428 }
429 
430 Type TypeAttr::getValue() const { return getImpl()->value; }
431 
432 //===----------------------------------------------------------------------===//
433 // ElementsAttr
434 //===----------------------------------------------------------------------===//
435 
436 ShapedType ElementsAttr::getType() const {
437   return Attribute::getType().cast<ShapedType>();
438 }
439 
440 /// Returns the number of elements held by this attribute.
441 int64_t ElementsAttr::getNumElements() const {
442   return getType().getNumElements();
443 }
444 
445 /// Return the value at the given index. If index does not refer to a valid
446 /// element, then a null attribute is returned.
447 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
448   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
449     return denseAttr.getValue(index);
450   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
451     return opaqueAttr.getValue(index);
452   return cast<SparseElementsAttr>().getValue(index);
453 }
454 
455 /// Return if the given 'index' refers to a valid element in this attribute.
456 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
457   auto type = getType();
458 
459   // Verify that the rank of the indices matches the held type.
460   auto rank = type.getRank();
461   if (rank == 0 && index.size() == 1 && index[0] == 0)
462     return true;
463   if (rank != static_cast<int64_t>(index.size()))
464     return false;
465 
466   // Verify that all of the indices are within the shape dimensions.
467   auto shape = type.getShape();
468   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
469     return static_cast<int64_t>(index[i]) < shape[i];
470   });
471 }
472 
473 ElementsAttr
474 ElementsAttr::mapValues(Type newElementType,
475                         function_ref<APInt(const APInt &)> mapping) const {
476   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
477     return intOrFpAttr.mapValues(newElementType, mapping);
478   llvm_unreachable("unsupported ElementsAttr subtype");
479 }
480 
481 ElementsAttr
482 ElementsAttr::mapValues(Type newElementType,
483                         function_ref<APInt(const APFloat &)> mapping) const {
484   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
485     return intOrFpAttr.mapValues(newElementType, mapping);
486   llvm_unreachable("unsupported ElementsAttr subtype");
487 }
488 
489 /// Method for support type inquiry through isa, cast and dyn_cast.
490 bool ElementsAttr::classof(Attribute attr) {
491   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
492                   OpaqueElementsAttr, SparseElementsAttr>();
493 }
494 
495 /// Returns the 1 dimensional flattened row-major index from the given
496 /// multi-dimensional index.
497 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
498   assert(isValidIndex(index) && "expected valid multi-dimensional index");
499   auto type = getType();
500 
501   // Reduce the provided multidimensional index into a flattended 1D row-major
502   // index.
503   auto rank = type.getRank();
504   auto shape = type.getShape();
505   uint64_t valueIndex = 0;
506   uint64_t dimMultiplier = 1;
507   for (int i = rank - 1; i >= 0; --i) {
508     valueIndex += index[i] * dimMultiplier;
509     dimMultiplier *= shape[i];
510   }
511   return valueIndex;
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // DenseElementsAttr Utilities
516 //===----------------------------------------------------------------------===//
517 
518 /// Get the bitwidth of a dense element type within the buffer.
519 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
520 static size_t getDenseElementStorageWidth(size_t origWidth) {
521   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
522 }
523 static size_t getDenseElementStorageWidth(Type elementType) {
524   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
525 }
526 
/// Sets or clears the bit at absolute bit position `bitPos` within `rawData`.
/// Bits are numbered from the least significant bit of each byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}
534 
/// Returns whether the bit at absolute bit position `bitPos` within `rawData`
/// is set. Bits are numbered from the least significant bit of each byte.
static bool getBit(const char *rawData, size_t bitPos) {
  const size_t byteIndex = bitPos / CHAR_BIT;
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[byteIndex] & mask) != 0;
}
539 
/// Copy the `numBytes` bytes of actual data from `value` (APInt) to the char
/// array `result`, for hosts using big-endian (BE) format.
///
/// All words of `value` except the last are copied verbatim; the last word is
/// routed through a temporary buffer with two endianness conversions so that
/// only its `numBytes - lastWordPos` data bytes end up in `result`.
/// Asserts that the host is big-endian and that `value` has room for at least
/// `numBytes` bytes.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a byte count (the number of
  // words before the last one times the word size in bytes).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
569 
/// Copy `numBytes` bytes of data from the char array `inArray` into `result`
/// (APInt), for hosts using big-endian (BE) format.
///
/// The bytes that completely fill the leading words of `result` are copied
/// verbatim; the remaining bytes are routed through a temporary buffer with
/// two endianness conversions and written to the last word of `result`.
/// Asserts that the host is big-endian and that `result` has room for at
/// least `numBytes` bytes. The raw words of `result` are written through a
/// const_cast of `getRawData()`.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: despite its name, `numFilledWords` is a byte count (the number of
  // words before the last one times the word size in bytes).
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
605 
606 /// Writes value to the bit position `bitPos` in array `rawData`.
607 static void writeBits(char *rawData, size_t bitPos, APInt value) {
608   size_t bitWidth = value.getBitWidth();
609 
610   // If the bitwidth is 1 we just toggle the specific bit.
611   if (bitWidth == 1)
612     return setBit(rawData, bitPos, value.isOneValue());
613 
614   // Otherwise, the bit position is guaranteed to be byte aligned.
615   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
616   if (llvm::support::endian::system_endianness() ==
617       llvm::support::endianness::big) {
618     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
619     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
620     // work correctly in BE format.
621     // ex. `value` (2 words including 10 bytes)
622     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
623     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
624                                  rawData + (bitPos / CHAR_BIT));
625   } else {
626     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
627                 llvm::divideCeil(bitWidth, CHAR_BIT),
628                 rawData + (bitPos / CHAR_BIT));
629   }
630 }
631 
632 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
633 /// `rawData`.
634 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
635   // Handle a boolean bit position.
636   if (bitWidth == 1)
637     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
638 
639   // Otherwise, the bit position must be 8-bit aligned.
640   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
641   APInt result(bitWidth, 0);
642   if (llvm::support::endian::system_endianness() ==
643       llvm::support::endianness::big) {
644     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
645     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
646     // work correctly in BE format.
647     // ex. `result` (2 words including 10 bytes)
648     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
649     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
650                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
651   } else {
652     std::copy_n(rawData + (bitPos / CHAR_BIT),
653                 llvm::divideCeil(bitWidth, CHAR_BIT),
654                 const_cast<char *>(
655                     reinterpret_cast<const char *>(result.getRawData())));
656   }
657   return result;
658 }
659 
660 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
661 /// the same element count as 'type'.
662 template <typename Values>
663 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
664   return (values.size() == 1) ||
665          (type.getNumElements() == static_cast<int64_t>(values.size()));
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // DenseElementsAttr Iterators
670 //===----------------------------------------------------------------------===//
671 
672 //===----------------------------------------------------------------------===//
673 // AttributeElementIterator
674 
/// Constructs an iterator positioned at `index`. The owning attribute is kept
/// as its opaque pointer and serves as the base of the indexed accessor
/// iterator.
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
    DenseElementsAttr attr, size_t index)
    : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                      Attribute, Attribute, Attribute>(
          attr.getAsOpaquePointer(), index) {}
680 
681 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
682   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
683   Type eltTy = owner.getType().getElementType();
684   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
685     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
686   if (eltTy.isa<IndexType>())
687     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
688   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
689     IntElementIterator intIt(owner, index);
690     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
691     return FloatAttr::get(eltTy, *floatIt);
692   }
693   if (owner.isa<DenseStringElementsAttr>()) {
694     ArrayRef<StringRef> vals = owner.getRawStringData();
695     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
696   }
697   llvm_unreachable("unexpected element type");
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // BoolElementIterator
702 
/// Constructs a bool iterator over the raw data of `attr`, positioned at
/// `dataIndex`. Whether `attr` is a splat is forwarded to the base iterator.
DenseElementsAttr::BoolElementIterator::BoolElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
          attr.getRawData().data(), attr.isSplat(), dataIndex) {}
707 
708 bool DenseElementsAttr::BoolElementIterator::operator*() const {
709   return getBit(getData(), getDataIndex());
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // IntElementIterator
714 
/// Constructs an integer iterator over the raw data of `attr`, positioned at
/// `dataIndex`. The element bit width is taken from the attribute's element
/// type and cached in `bitWidth`.
DenseElementsAttr::IntElementIterator::IntElementIterator(
    DenseElementsAttr attr, size_t dataIndex)
    : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
          attr.getRawData().data(), attr.isSplat(), dataIndex),
      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
720 
721 APInt DenseElementsAttr::IntElementIterator::operator*() const {
722   return readBits(getData(),
723                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
724                   bitWidth);
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // ComplexIntElementIterator
729 
730 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
731     DenseElementsAttr attr, size_t dataIndex)
732     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
733                                       std::complex<APInt>, std::complex<APInt>,
734                                       std::complex<APInt>>(
735           attr.getRawData().data(), attr.isSplat(), dataIndex) {
736   auto complexType = attr.getType().getElementType().cast<ComplexType>();
737   bitWidth = getDenseElementBitWidth(complexType.getElementType());
738 }
739 
740 std::complex<APInt>
741 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
742   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
743   size_t offset = getDataIndex() * storageWidth * 2;
744   return {readBits(getData(), offset, bitWidth),
745           readBits(getData(), offset + storageWidth, bitWidth)};
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // FloatElementIterator
750 
751 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
752     const llvm::fltSemantics &smt, IntElementIterator it)
753     : llvm::mapped_iterator<IntElementIterator,
754                             std::function<APFloat(const APInt &)>>(
755           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
756 
757 //===----------------------------------------------------------------------===//
758 // ComplexFloatElementIterator
759 
760 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
761     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
762     : llvm::mapped_iterator<
763           ComplexIntElementIterator,
764           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
765           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
766             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
767           }) {}
768 
769 //===----------------------------------------------------------------------===//
770 // DenseElementsAttr
771 //===----------------------------------------------------------------------===//
772 
773 /// Method for support type inquiry through isa, cast and dyn_cast.
774 bool DenseElementsAttr::classof(Attribute attr) {
775   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
776 }
777 
778 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
779                                          ArrayRef<Attribute> values) {
780   assert(hasSameElementsOrSplat(type, values));
781 
782   // If the element type is not based on int/float/index, assume it is a string
783   // type.
784   auto eltType = type.getElementType();
785   if (!type.getElementType().isIntOrIndexOrFloat()) {
786     SmallVector<StringRef, 8> stringValues;
787     stringValues.reserve(values.size());
788     for (Attribute attr : values) {
789       assert(attr.isa<StringAttr>() &&
790              "expected string value for non integer/index/float element");
791       stringValues.push_back(attr.cast<StringAttr>().getValue());
792     }
793     return get(type, stringValues);
794   }
795 
796   // Otherwise, get the raw storage width to use for the allocation.
797   size_t bitWidth = getDenseElementBitWidth(eltType);
798   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
799 
800   // Compress the attribute values into a character buffer.
801   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
802                             values.size());
803   APInt intVal;
804   for (unsigned i = 0, e = values.size(); i < e; ++i) {
805     assert(eltType == values[i].getType() &&
806            "expected attribute value to have element type");
807     if (eltType.isa<FloatType>())
808       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
809     else if (eltType.isa<IntegerType>())
810       intVal = values[i].cast<IntegerAttr>().getValue();
811     else
812       llvm_unreachable("unexpected element type");
813 
814     assert(intVal.getBitWidth() == bitWidth &&
815            "expected value to have same bitwidth as element type");
816     writeBits(data.data(), i * storageBitWidth, intVal);
817   }
818   return DenseIntOrFPElementsAttr::getRaw(type, data,
819                                           /*isSplat=*/(values.size() == 1));
820 }
821 
822 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
823                                          ArrayRef<bool> values) {
824   assert(hasSameElementsOrSplat(type, values));
825   assert(type.getElementType().isInteger(1));
826 
827   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
828   for (int i = 0, e = values.size(); i != e; ++i)
829     setBit(buff.data(), i, values[i]);
830   return DenseIntOrFPElementsAttr::getRaw(type, buff,
831                                           /*isSplat=*/(values.size() == 1));
832 }
833 
/// Constructs a dense string elements attribute from an array of StringRef
/// values. The element type of `type` must not be an integer or float type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  assert(!type.getElementType().isIntOrFloat());
  return DenseStringElementsAttr::get(type, values);
}
839 
840 /// Constructs a dense integer elements attribute from an array of APInt
841 /// values. Each APInt value is expected to have the same bitwidth as the
842 /// element type of 'type'.
843 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
844                                          ArrayRef<APInt> values) {
845   assert(type.getElementType().isIntOrIndex());
846   assert(hasSameElementsOrSplat(type, values));
847   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
848   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
849                                           /*isSplat=*/(values.size() == 1));
850 }
851 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
852                                          ArrayRef<std::complex<APInt>> values) {
853   ComplexType complex = type.getElementType().cast<ComplexType>();
854   assert(complex.getElementType().isa<IntegerType>());
855   assert(hasSameElementsOrSplat(type, values));
856   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
857   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
858                           values.size() * 2);
859   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
860                                           /*isSplat=*/(values.size() == 1));
861 }
862 
863 // Constructs a dense float elements attribute from an array of APFloat
864 // values. Each APFloat value is expected to have the same bitwidth as the
865 // element type of 'type'.
866 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
867                                          ArrayRef<APFloat> values) {
868   assert(type.getElementType().isa<FloatType>());
869   assert(hasSameElementsOrSplat(type, values));
870   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
871   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
872                                           /*isSplat=*/(values.size() == 1));
873 }
874 DenseElementsAttr
875 DenseElementsAttr::get(ShapedType type,
876                        ArrayRef<std::complex<APFloat>> values) {
877   ComplexType complex = type.getElementType().cast<ComplexType>();
878   assert(complex.getElementType().isa<FloatType>());
879   assert(hasSameElementsOrSplat(type, values));
880   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
881                            values.size() * 2);
882   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
883   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
884                                           /*isSplat=*/(values.size() == 1));
885 }
886 
/// Construct a dense elements attribute from a raw buffer representing the
/// data for this attribute. Users should generally not use this method, as
/// the expected buffer format may not be a form the user expects.
/// `isSplatBuffer` states whether `rawBuffer` holds a single splat element.
DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
                                                      ArrayRef<char> rawBuffer,
                                                      bool isSplatBuffer) {
  return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
}
895 
896 /// Returns true if the given buffer is a valid raw buffer for the given type.
897 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
898                                          ArrayRef<char> rawBuffer,
899                                          bool &detectedSplat) {
900   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
901   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
902 
903   // Storage width of 1 is special as it is packed by the bit.
904   if (storageWidth == 1) {
905     // Check for a splat, or a buffer equal to the number of elements.
906     if ((detectedSplat = rawBuffer.size() == 1))
907       return true;
908     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
909   }
910   // All other types are 8-bit aligned.
911   if ((detectedSplat = rawBufferWidth == storageWidth))
912     return true;
913   return rawBufferWidth == (storageWidth * type.getNumElements());
914 }
915 
916 /// Check the information for a C++ data type, check if this type is valid for
917 /// the current attribute. This method is used to verify specific type
918 /// invariants that the templatized 'getValues' method cannot.
919 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
920                               bool isSigned) {
921   // Make sure that the data element size is the same as the type element width.
922   if (getDenseElementBitWidth(type) !=
923       static_cast<size_t>(dataEltSize * CHAR_BIT))
924     return false;
925 
926   // Check that the element type is either float or integer or index.
927   if (!isInt)
928     return type.isa<FloatType>();
929   if (type.isIndex())
930     return true;
931 
932   auto intType = type.dyn_cast<IntegerType>();
933   if (!intType)
934     return false;
935 
936   // Make sure signedness semantics is consistent.
937   if (intType.isSignless())
938     return true;
939   return intType.isSigned() ? isSigned : !isSigned;
940 }
941 
/// Forwards to the DenseIntOrFPElementsAttr implementation of 'getRawComplex'
/// with the same arguments.
DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
                                                   ArrayRef<char> data,
                                                   int64_t dataEltSize,
                                                   bool isInt, bool isSigned) {
  return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
                                                 isSigned);
}
/// Forwards to the DenseIntOrFPElementsAttr implementation of
/// 'getRawIntOrFloat' with the same arguments.
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
                                                      ArrayRef<char> data,
                                                      int64_t dataEltSize,
                                                      bool isInt,
                                                      bool isSigned) {
  return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
                                                    isInt, isSigned);
}
958 
959 /// A method used to verify specific type invariants that the templatized 'get'
960 /// method cannot.
961 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
962                                           bool isSigned) const {
963   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
964                              isSigned);
965 }
966 
967 /// Check the information for a C++ data type, check if this type is valid for
968 /// the current attribute.
969 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
970                                        bool isSigned) const {
971   return ::isValidIntOrFloat(
972       getType().getElementType().cast<ComplexType>().getElementType(),
973       dataEltSize / 2, isInt, isSigned);
974 }
975 
976 /// Returns true if this attribute corresponds to a splat, i.e. if all element
977 /// values are the same.
978 bool DenseElementsAttr::isSplat() const {
979   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
980 }
981 
982 /// Return the held element values as a range of Attributes.
983 auto DenseElementsAttr::getAttributeValues() const
984     -> llvm::iterator_range<AttributeElementIterator> {
985   return {attr_value_begin(), attr_value_end()};
986 }
/// Return an Attribute iterator positioned at the first element (index 0).
auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
  return AttributeElementIterator(*this, 0);
}
/// Return an Attribute iterator positioned one past the last element, i.e. at
/// index getNumElements().
auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
  return AttributeElementIterator(*this, getNumElements());
}
993 
994 /// Return the held element values as a range of bool. The element type of
995 /// this attribute must be of integer type of bitwidth 1.
996 auto DenseElementsAttr::getBoolValues() const
997     -> llvm::iterator_range<BoolElementIterator> {
998   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
999   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
1000   (void)eltType;
1001   return {BoolElementIterator(*this, 0),
1002           BoolElementIterator(*this, getNumElements())};
1003 }
1004 
1005 /// Return the held element values as a range of APInts. The element type of
1006 /// this attribute must be of integer type.
1007 auto DenseElementsAttr::getIntValues() const
1008     -> llvm::iterator_range<IntElementIterator> {
1009   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1010   return {raw_int_begin(), raw_int_end()};
1011 }
/// Return the begin iterator over the APInt element values. The element type
/// of this attribute must be of integer or index type.
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
  return raw_int_begin();
}
/// Return the end iterator over the APInt element values. The element type of
/// this attribute must be of integer or index type.
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
  return raw_int_end();
}
1020 auto DenseElementsAttr::getComplexIntValues() const
1021     -> llvm::iterator_range<ComplexIntElementIterator> {
1022   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1023   (void)eltTy;
1024   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
1025   return {ComplexIntElementIterator(*this, 0),
1026           ComplexIntElementIterator(*this, getNumElements())};
1027 }
1028 
1029 /// Return the held element values as a range of APFloat. The element type of
1030 /// this attribute must be of float type.
1031 auto DenseElementsAttr::getFloatValues() const
1032     -> llvm::iterator_range<FloatElementIterator> {
1033   auto elementType = getType().getElementType().cast<FloatType>();
1034   const auto &elementSemantics = elementType.getFloatSemantics();
1035   return {FloatElementIterator(elementSemantics, raw_int_begin()),
1036           FloatElementIterator(elementSemantics, raw_int_end())};
1037 }
/// Return the begin iterator of the range produced by getFloatValues().
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
  return getFloatValues().begin();
}
/// Return the end iterator of the range produced by getFloatValues().
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
  return getFloatValues().end();
}
1044 auto DenseElementsAttr::getComplexFloatValues() const
1045     -> llvm::iterator_range<ComplexFloatElementIterator> {
1046   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1047   assert(eltTy.isa<FloatType>() && "expected complex float type");
1048   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1049   return {{semantics, {*this, 0}},
1050           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1051 }
1052 
1053 /// Return the raw storage data held by this attribute.
1054 ArrayRef<char> DenseElementsAttr::getRawData() const {
1055   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1056 }
1057 
1058 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1059   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1060 }
1061 
1062 /// Return a new DenseElementsAttr that has the same data as the current
1063 /// attribute, but has been reshaped to 'newType'. The new type must have the
1064 /// same total number of elements as well as element type.
1065 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1066   ShapedType curType = getType();
1067   if (curType == newType)
1068     return *this;
1069 
1070   (void)curType;
1071   assert(newType.getElementType() == curType.getElementType() &&
1072          "expected the same element type");
1073   assert(newType.getNumElements() == curType.getNumElements() &&
1074          "expected the same number of elements");
1075   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1076 }
1077 
1078 DenseElementsAttr
1079 DenseElementsAttr::mapValues(Type newElementType,
1080                              function_ref<APInt(const APInt &)> mapping) const {
1081   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1082 }
1083 
1084 DenseElementsAttr DenseElementsAttr::mapValues(
1085     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1086   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1087 }
1088 
1089 //===----------------------------------------------------------------------===//
1090 // DenseStringElementsAttr
1091 //===----------------------------------------------------------------------===//
1092 
1093 DenseStringElementsAttr
1094 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1095   return Base::get(type.getContext(), type, values, (values.size() == 1));
1096 }
1097 
1098 //===----------------------------------------------------------------------===//
1099 // DenseIntOrFPElementsAttr
1100 //===----------------------------------------------------------------------===//
1101 
1102 /// Utility method to write a range of APInt values to a buffer.
1103 template <typename APRangeT>
1104 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1105                                 APRangeT &&values) {
1106   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1107   size_t offset = 0;
1108   for (auto it = values.begin(), e = values.end(); it != e;
1109        ++it, offset += storageWidth) {
1110     assert((*it).getBitWidth() <= storageWidth);
1111     writeBits(data.data(), offset, *it);
1112   }
1113 }
1114 
1115 /// Constructs a dense elements attribute from an array of raw APFloat values.
1116 /// Each APFloat value is expected to have the same bitwidth as the element
1117 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1118 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1119                                                    size_t storageWidth,
1120                                                    ArrayRef<APFloat> values,
1121                                                    bool isSplat) {
1122   std::vector<char> data;
1123   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1124   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1125   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1126 }
1127 
1128 /// Constructs a dense elements attribute from an array of raw APInt values.
1129 /// Each APInt value is expected to have the same bitwidth as the element type
1130 /// of 'type'.
1131 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1132                                                    size_t storageWidth,
1133                                                    ArrayRef<APInt> values,
1134                                                    bool isSplat) {
1135   std::vector<char> data;
1136   writeAPIntsToBuffer(storageWidth, data, values);
1137   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1138 }
1139 
1140 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1141                                                    ArrayRef<char> data,
1142                                                    bool isSplat) {
1143   assert((type.isa<RankedTensorType, VectorType>()) &&
1144          "type must be ranked tensor or vector");
1145   assert(type.hasStaticShape() && "type must have static shape");
1146   return Base::get(type.getContext(), type, data, isSplat);
1147 }
1148 
1149 /// Overload of the raw 'get' method that asserts that the given type is of
1150 /// complex type. This method is used to verify type invariants that the
1151 /// templatized 'get' method cannot.
1152 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1153                                                           ArrayRef<char> data,
1154                                                           int64_t dataEltSize,
1155                                                           bool isInt,
1156                                                           bool isSigned) {
1157   assert(::isValidIntOrFloat(
1158       type.getElementType().cast<ComplexType>().getElementType(),
1159       dataEltSize / 2, isInt, isSigned));
1160 
1161   int64_t numElements = data.size() / dataEltSize;
1162   assert(numElements == 1 || numElements == type.getNumElements());
1163   return getRaw(type, data, /*isSplat=*/numElements == 1);
1164 }
1165 
1166 /// Overload of the 'getRaw' method that asserts that the given type is of
1167 /// integer type. This method is used to verify type invariants that the
1168 /// templatized 'get' method cannot.
1169 DenseElementsAttr
1170 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1171                                            int64_t dataEltSize, bool isInt,
1172                                            bool isSigned) {
1173   assert(
1174       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1175 
1176   int64_t numElements = data.size() / dataEltSize;
1177   assert(numElements == 1 || numElements == type.getNumElements());
1178   return getRaw(type, data, /*isSplat=*/numElements == 1);
1179 }
1180 
1181 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1182     const char *inRawData, char *outRawData, size_t elementBitWidth,
1183     size_t numElements) {
1184   using llvm::support::ulittle16_t;
1185   using llvm::support::ulittle32_t;
1186   using llvm::support::ulittle64_t;
1187 
1188   assert(llvm::support::endian::system_endianness() == // NOLINT
1189          llvm::support::endianness::big);              // NOLINT
1190   // NOLINT to avoid warning message about replacing by static_assert()
1191 
1192   // Following std::copy_n always converts endianness on BE machine.
1193   switch (elementBitWidth) {
1194   case 16: {
1195     const ulittle16_t *inRawDataPos =
1196         reinterpret_cast<const ulittle16_t *>(inRawData);
1197     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1198     std::copy_n(inRawDataPos, numElements, outDataPos);
1199     break;
1200   }
1201   case 32: {
1202     const ulittle32_t *inRawDataPos =
1203         reinterpret_cast<const ulittle32_t *>(inRawData);
1204     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1205     std::copy_n(inRawDataPos, numElements, outDataPos);
1206     break;
1207   }
1208   case 64: {
1209     const ulittle64_t *inRawDataPos =
1210         reinterpret_cast<const ulittle64_t *>(inRawData);
1211     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1212     std::copy_n(inRawDataPos, numElements, outDataPos);
1213     break;
1214   }
1215   default: {
1216     size_t nBytes = elementBitWidth / CHAR_BIT;
1217     for (size_t i = 0; i < nBytes; i++)
1218       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1219     break;
1220   }
1221   }
1222 }
1223 
1224 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1225     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1226     ShapedType type) {
1227   size_t numElements = type.getNumElements();
1228   Type elementType = type.getElementType();
1229   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1230     elementType = complexTy.getElementType();
1231     numElements = numElements * 2;
1232   }
1233   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1234   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1235          inRawData.size() <= outRawData.size());
1236   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1237                                   elementBitWidth, numElements);
1238 }
1239 
1240 //===----------------------------------------------------------------------===//
1241 // DenseFPElementsAttr
1242 //===----------------------------------------------------------------------===//
1243 
1244 template <typename Fn, typename Attr>
1245 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1246                                 Type newElementType,
1247                                 llvm::SmallVectorImpl<char> &data) {
1248   size_t bitWidth = getDenseElementBitWidth(newElementType);
1249   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1250 
1251   ShapedType newArrayType;
1252   if (inType.isa<RankedTensorType>())
1253     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1254   else if (inType.isa<UnrankedTensorType>())
1255     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1256   else if (inType.isa<VectorType>())
1257     newArrayType = VectorType::get(inType.getShape(), newElementType);
1258   else
1259     assert(newArrayType && "Unhandled tensor type");
1260 
1261   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1262   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1263 
1264   // Functor used to process a single element value of the attribute.
1265   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1266     auto newInt = mapping(value);
1267     assert(newInt.getBitWidth() == bitWidth);
1268     writeBits(data.data(), index * storageBitWidth, newInt);
1269   };
1270 
1271   // Check for the splat case.
1272   if (attr.isSplat()) {
1273     processElt(*attr.begin(), /*index=*/0);
1274     return newArrayType;
1275   }
1276 
1277   // Otherwise, process all of the element values.
1278   uint64_t elementIdx = 0;
1279   for (auto value : attr)
1280     processElt(value, elementIdx++);
1281   return newArrayType;
1282 }
1283 
1284 DenseElementsAttr DenseFPElementsAttr::mapValues(
1285     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1286   llvm::SmallVector<char, 8> elementData;
1287   auto newArrayType =
1288       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1289 
1290   return getRaw(newArrayType, elementData, isSplat());
1291 }
1292 
1293 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1294 bool DenseFPElementsAttr::classof(Attribute attr) {
1295   return attr.isa<DenseElementsAttr>() &&
1296          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // DenseIntElementsAttr
1301 //===----------------------------------------------------------------------===//
1302 
1303 DenseElementsAttr DenseIntElementsAttr::mapValues(
1304     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1305   llvm::SmallVector<char, 8> elementData;
1306   auto newArrayType =
1307       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1308 
1309   return getRaw(newArrayType, elementData, isSplat());
1310 }
1311 
1312 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1313 bool DenseIntElementsAttr::classof(Attribute attr) {
1314   return attr.isa<DenseElementsAttr>() &&
1315          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1316 }
1317 
1318 //===----------------------------------------------------------------------===//
1319 // OpaqueElementsAttr
1320 //===----------------------------------------------------------------------===//
1321 
1322 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1323                                            StringRef bytes) {
1324   assert(TensorType::isValidElementType(type.getElementType()) &&
1325          "Input element type should be a valid tensor element type");
1326   return Base::get(type.getContext(), type, dialect, bytes);
1327 }
1328 
/// Return the raw bytes held by this attribute's storage.
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1330 
/// Return the value at the given index. `index` must be a valid
/// multi-dimensional index for this attribute. This implementation does not
/// look at the stored bytes and always returns a null attribute.
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  return Attribute();
}
1337 
/// Return the dialect recorded in this attribute's storage.
Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1339 
1340 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1341   auto *d = getDialect();
1342   if (!d)
1343     return true;
1344   auto *interface =
1345       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1346   if (!interface)
1347     return true;
1348   return failed(interface->decode(*this, result));
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // SparseElementsAttr
1353 //===----------------------------------------------------------------------===//
1354 
1355 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1356                                            DenseElementsAttr indices,
1357                                            DenseElementsAttr values) {
1358   assert(indices.getType().getElementType().isInteger(64) &&
1359          "expected sparse indices to be 64-bit integer values");
1360   assert((type.isa<RankedTensorType, VectorType>()) &&
1361          "type must be ranked tensor or vector");
1362   assert(type.hasStaticShape() && "type must have static shape");
1363   return Base::get(type.getContext(), type,
1364                    indices.cast<DenseIntElementsAttr>(), values);
1365 }
1366 
/// Return the dense attribute holding the indices of the stored elements.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  return getImpl()->indices;
}
1370 
/// Return the dense attribute holding the values of the stored elements.
DenseElementsAttr SparseElementsAttr::getValues() const {
  return getImpl()->values;
}
1374 
/// Return the value of the element at the given index. `index` must be a
/// valid multi-dimensional index for this attribute. If no sparse entry is
/// stored for `index`, the zero attribute of the element type is returned.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // If the index is also not a splat of the index value, we know that the
    // value is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // Each key is an ArrayRef viewing `rank` consecutive index values starting
  // at position `i * rank` of the flattened index data; try_emplace keeps the
  // first entry if the same index appears more than once.
  // NOTE(review): taking the address of the dereferenced iterator assumes it
  // yields a reference into contiguous index storage -- confirm.
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
1415 
1416 /// Get a zero APFloat for the given sparse attribute.
1417 APFloat SparseElementsAttr::getZeroAPFloat() const {
1418   auto eltType = getType().getElementType().cast<FloatType>();
1419   return APFloat(eltType.getFloatSemantics());
1420 }
1421 
1422 /// Get a zero APInt for the given sparse attribute.
1423 APInt SparseElementsAttr::getZeroAPInt() const {
1424   auto eltType = getType().getElementType().cast<IntegerType>();
1425   return APInt::getNullValue(eltType.getWidth());
1426 }
1427 
1428 /// Get a zero attribute for the given attribute type.
1429 Attribute SparseElementsAttr::getZeroAttr() const {
1430   auto eltType = getType().getElementType();
1431 
1432   // Handle floating point elements.
1433   if (eltType.isa<FloatType>())
1434     return FloatAttr::get(eltType, 0);
1435 
1436   // Otherwise, this is an integer.
1437   // TODO: Handle StringAttr here.
1438   return IntegerAttr::get(eltType, 0);
1439 }
1440 
1441 /// Flatten, and return, all of the sparse indices in this attribute in
1442 /// row-major order.
1443 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1444   std::vector<ptrdiff_t> flatSparseIndices;
1445 
1446   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1447   // as a 1-D index array.
1448   auto sparseIndices = getIndices();
1449   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1450   if (sparseIndices.isSplat()) {
1451     SmallVector<uint64_t, 8> indices(getType().getRank(),
1452                                      *sparseIndexValues.begin());
1453     flatSparseIndices.push_back(getFlattenedIndex(indices));
1454     return flatSparseIndices;
1455   }
1456 
1457   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1458   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1459   size_t rank = getType().getRank();
1460   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1461     flatSparseIndices.push_back(getFlattenedIndex(
1462         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1463   return flatSparseIndices;
1464 }
1465