//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 //===----------------------------------------------------------------------===//
26 // AttributeStorage
27 //===----------------------------------------------------------------------===//
28 
29 AttributeStorage::AttributeStorage(Type type)
30     : type(type.getAsOpaquePointer()) {}
31 AttributeStorage::AttributeStorage() : type(nullptr) {}
32 
33 Type AttributeStorage::getType() const {
34   return Type::getFromOpaquePointer(type);
35 }
36 void AttributeStorage::setType(Type newType) {
37   type = newType.getAsOpaquePointer();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // Attribute
42 //===----------------------------------------------------------------------===//
43 
44 /// Return the type of this attribute.
45 Type Attribute::getType() const { return impl->getType(); }
46 
47 /// Return the context this attribute belongs to.
48 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
49 
50 /// Get the dialect this attribute is registered to.
51 Dialect &Attribute::getDialect() const {
52   return impl->getAbstractAttribute().getDialect();
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // AffineMapAttr
57 //===----------------------------------------------------------------------===//
58 
59 AffineMapAttr AffineMapAttr::get(AffineMap value) {
60   return Base::get(value.getContext(), value);
61 }
62 
63 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
64 
65 //===----------------------------------------------------------------------===//
66 // ArrayAttr
67 //===----------------------------------------------------------------------===//
68 
69 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
70   return Base::get(context, value);
71 }
72 
73 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
74 
75 Attribute ArrayAttr::operator[](unsigned idx) const {
76   assert(idx < size() && "index out of bounds");
77   return getValue()[idx];
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // DictionaryAttr
82 //===----------------------------------------------------------------------===//
83 
84 /// Helper function that does either an in place sort or sorts from source array
85 /// into destination. If inPlace then storage is both the source and the
86 /// destination, else value is the source and storage destination. Returns
87 /// whether source was sorted.
88 template <bool inPlace>
89 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
90                                SmallVectorImpl<NamedAttribute> &storage) {
91   // Specialize for the common case.
92   switch (value.size()) {
93   case 0:
94     // Zero already sorted.
95     break;
96   case 1:
97     // One already sorted but may need to be copied.
98     if (!inPlace)
99       storage.assign({value[0]});
100     break;
101   case 2: {
102     bool isSorted = value[0] < value[1];
103     if (inPlace) {
104       if (!isSorted)
105         std::swap(storage[0], storage[1]);
106     } else if (isSorted) {
107       storage.assign({value[0], value[1]});
108     } else {
109       storage.assign({value[1], value[0]});
110     }
111     return !isSorted;
112   }
113   default:
114     if (!inPlace)
115       storage.assign(value.begin(), value.end());
116     // Check to see they are sorted already.
117     bool isSorted = llvm::is_sorted(value);
118     if (!isSorted) {
119       // If not, do a general sort.
120       llvm::array_pod_sort(storage.begin(), storage.end());
121       value = storage;
122     }
123     return !isSorted;
124   }
125   return false;
126 }
127 
128 /// Returns an entry with a duplicate name from the given sorted array of named
129 /// attributes. Returns llvm::None if all elements have unique names.
130 static Optional<NamedAttribute>
131 findDuplicateElement(ArrayRef<NamedAttribute> value) {
132   const Optional<NamedAttribute> none{llvm::None};
133   if (value.size() < 2)
134     return none;
135 
136   if (value.size() == 2)
137     return value[0].first == value[1].first ? value[0] : none;
138 
139   auto it = std::adjacent_find(
140       value.begin(), value.end(),
141       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
142   return it != value.end() ? *it : none;
143 }
144 
145 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
146                           SmallVectorImpl<NamedAttribute> &storage) {
147   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
148   assert(!findDuplicateElement(storage) &&
149          "DictionaryAttr element names must be unique");
150   return isSorted;
151 }
152 
153 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
154   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
155   assert(!findDuplicateElement(array) &&
156          "DictionaryAttr element names must be unique");
157   return isSorted;
158 }
159 
160 Optional<NamedAttribute>
161 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
162                               bool isSorted) {
163   if (!isSorted)
164     dictionaryAttrSort</*inPlace=*/true>(array, array);
165   return findDuplicateElement(array);
166 }
167 
168 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
169                                    MLIRContext *context) {
170   if (value.empty())
171     return DictionaryAttr::getEmpty(context);
172   assert(llvm::all_of(value,
173                       [](const NamedAttribute &attr) { return attr.second; }) &&
174          "value cannot have null entries");
175 
176   // We need to sort the element list to canonicalize it.
177   SmallVector<NamedAttribute, 8> storage;
178   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
179     value = storage;
180   assert(!findDuplicateElement(value) &&
181          "DictionaryAttr element names must be unique");
182   return Base::get(context, value);
183 }
184 /// Construct a dictionary with an array of values that is known to already be
185 /// sorted by name and uniqued.
186 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
187                                              MLIRContext *context) {
188   if (value.empty())
189     return DictionaryAttr::getEmpty(context);
190   // Ensure that the attribute elements are unique and sorted.
191   assert(llvm::is_sorted(value,
192                          [](NamedAttribute l, NamedAttribute r) {
193                            return l.first.strref() < r.first.strref();
194                          }) &&
195          "expected attribute values to be sorted");
196   assert(!findDuplicateElement(value) &&
197          "DictionaryAttr element names must be unique");
198   return Base::get(context, value);
199 }
200 
201 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
202   return getImpl()->getElements();
203 }
204 
205 /// Return the specified attribute if present, null otherwise.
206 Attribute DictionaryAttr::get(StringRef name) const {
207   Optional<NamedAttribute> attr = getNamed(name);
208   return attr ? attr->second : nullptr;
209 }
210 Attribute DictionaryAttr::get(Identifier name) const {
211   Optional<NamedAttribute> attr = getNamed(name);
212   return attr ? attr->second : nullptr;
213 }
214 
215 /// Return the specified named attribute if present, None otherwise.
216 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
217   ArrayRef<NamedAttribute> values = getValue();
218   const auto *it = llvm::lower_bound(values, name);
219   return it != values.end() && it->first == name ? *it
220                                                  : Optional<NamedAttribute>();
221 }
222 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
223   for (auto elt : getValue())
224     if (elt.first == name)
225       return elt;
226   return llvm::None;
227 }
228 
229 DictionaryAttr::iterator DictionaryAttr::begin() const {
230   return getValue().begin();
231 }
232 DictionaryAttr::iterator DictionaryAttr::end() const {
233   return getValue().end();
234 }
235 size_t DictionaryAttr::size() const { return getValue().size(); }
236 
237 //===----------------------------------------------------------------------===//
238 // FloatAttr
239 //===----------------------------------------------------------------------===//
240 
241 FloatAttr FloatAttr::get(Type type, double value) {
242   return Base::get(type.getContext(), type, value);
243 }
244 
245 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
246   return Base::getChecked(loc, type, value);
247 }
248 
249 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
250   return Base::get(type.getContext(), type, value);
251 }
252 
253 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
254   return Base::getChecked(loc, type, value);
255 }
256 
257 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
258 
259 double FloatAttr::getValueAsDouble() const {
260   return getValueAsDouble(getValue());
261 }
262 double FloatAttr::getValueAsDouble(APFloat value) {
263   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
264     bool losesInfo = false;
265     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
266                   &losesInfo);
267   }
268   return value.convertToDouble();
269 }
270 
271 /// Verify construction invariants.
272 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
273   if (!type.isa<FloatType>())
274     return emitError(loc, "expected floating point type");
275   return success();
276 }
277 
278 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
279                                                       double value) {
280   return verifyFloatTypeInvariants(loc, type);
281 }
282 
283 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
284                                                       const APFloat &value) {
285   // Verify that the type is correct.
286   if (failed(verifyFloatTypeInvariants(loc, type)))
287     return failure();
288 
289   // Verify that the type semantics match that of the value.
290   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
291     return emitError(
292         loc, "FloatAttr type doesn't match the type implied by its value");
293   }
294   return success();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // SymbolRefAttr
299 //===----------------------------------------------------------------------===//
300 
301 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
302   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
303 }
304 
305 SymbolRefAttr SymbolRefAttr::get(StringRef value,
306                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
307                                  MLIRContext *ctx) {
308   return Base::get(ctx, value, nestedReferences);
309 }
310 
311 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
312 
313 StringRef SymbolRefAttr::getLeafReference() const {
314   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
315   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
316 }
317 
318 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
319   return getImpl()->getNestedRefs();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // IntegerAttr
324 //===----------------------------------------------------------------------===//
325 
326 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
327   if (type.isSignlessInteger(1))
328     return BoolAttr::get(value.getBoolValue(), type.getContext());
329   return Base::get(type.getContext(), type, value);
330 }
331 
332 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
333   // This uses 64 bit APInts by default for index type.
334   if (type.isIndex())
335     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
336 
337   auto intType = type.cast<IntegerType>();
338   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
339 }
340 
341 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
342 
343 int64_t IntegerAttr::getInt() const {
344   assert((getImpl()->getType().isIndex() ||
345           getImpl()->getType().isSignlessInteger()) &&
346          "must be signless integer");
347   return getValue().getSExtValue();
348 }
349 
350 int64_t IntegerAttr::getSInt() const {
351   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
352   return getValue().getSExtValue();
353 }
354 
355 uint64_t IntegerAttr::getUInt() const {
356   assert(getImpl()->getType().isUnsignedInteger() &&
357          "must be unsigned integer");
358   return getValue().getZExtValue();
359 }
360 
361 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
362   if (type.isa<IntegerType, IndexType>())
363     return success();
364   return emitError(loc, "expected integer or index type");
365 }
366 
367 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
368                                                         int64_t value) {
369   return verifyIntegerTypeInvariants(loc, type);
370 }
371 
372 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
373                                                         const APInt &value) {
374   if (failed(verifyIntegerTypeInvariants(loc, type)))
375     return failure();
376   if (auto integerType = type.dyn_cast<IntegerType>())
377     if (integerType.getWidth() != value.getBitWidth())
378       return emitError(loc, "integer type bit width (")
379              << integerType.getWidth() << ") doesn't match value bit width ("
380              << value.getBitWidth() << ")";
381   return success();
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // BoolAttr
386 
387 bool BoolAttr::getValue() const {
388   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
389   return storage->getValue().getBoolValue();
390 }
391 
392 bool BoolAttr::classof(Attribute attr) {
393   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
394   return intAttr && intAttr.getType().isSignlessInteger(1);
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // IntegerSetAttr
399 //===----------------------------------------------------------------------===//
400 
401 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
402   return Base::get(value.getConstraint(0).getContext(), value);
403 }
404 
405 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
406 
407 //===----------------------------------------------------------------------===//
408 // OpaqueAttr
409 //===----------------------------------------------------------------------===//
410 
411 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
412                            MLIRContext *context) {
413   return Base::get(context, dialect, attrData, type);
414 }
415 
416 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
417                                   Type type, Location location) {
418   return Base::getChecked(location, dialect, attrData, type);
419 }
420 
421 /// Returns the dialect namespace of the opaque attribute.
422 Identifier OpaqueAttr::getDialectNamespace() const {
423   return getImpl()->dialectNamespace;
424 }
425 
426 /// Returns the raw attribute data of the opaque attribute.
427 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
428 
429 /// Verify the construction of an opaque attribute.
430 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
431                                                        Identifier dialect,
432                                                        StringRef attrData,
433                                                        Type type) {
434   if (!Dialect::isValidNamespace(dialect.strref()))
435     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
436   return success();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // StringAttr
441 //===----------------------------------------------------------------------===//
442 
443 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
444   return get(bytes, NoneType::get(context));
445 }
446 
447 /// Get an instance of a StringAttr with the given string and Type.
448 StringAttr StringAttr::get(StringRef bytes, Type type) {
449   return Base::get(type.getContext(), bytes, type);
450 }
451 
452 StringRef StringAttr::getValue() const { return getImpl()->value; }
453 
454 //===----------------------------------------------------------------------===//
455 // TypeAttr
456 //===----------------------------------------------------------------------===//
457 
458 TypeAttr TypeAttr::get(Type value) {
459   return Base::get(value.getContext(), value);
460 }
461 
462 Type TypeAttr::getValue() const { return getImpl()->value; }
463 
464 //===----------------------------------------------------------------------===//
465 // ElementsAttr
466 //===----------------------------------------------------------------------===//
467 
468 ShapedType ElementsAttr::getType() const {
469   return Attribute::getType().cast<ShapedType>();
470 }
471 
472 /// Returns the number of elements held by this attribute.
473 int64_t ElementsAttr::getNumElements() const {
474   return getType().getNumElements();
475 }
476 
477 /// Return the value at the given index. If index does not refer to a valid
478 /// element, then a null attribute is returned.
479 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
480   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
481     return denseAttr.getValue(index);
482   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
483     return opaqueAttr.getValue(index);
484   return cast<SparseElementsAttr>().getValue(index);
485 }
486 
487 /// Return if the given 'index' refers to a valid element in this attribute.
488 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
489   auto type = getType();
490 
491   // Verify that the rank of the indices matches the held type.
492   auto rank = type.getRank();
493   if (rank != static_cast<int64_t>(index.size()))
494     return false;
495 
496   // Verify that all of the indices are within the shape dimensions.
497   auto shape = type.getShape();
498   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
499     return static_cast<int64_t>(index[i]) < shape[i];
500   });
501 }
502 
503 ElementsAttr
504 ElementsAttr::mapValues(Type newElementType,
505                         function_ref<APInt(const APInt &)> mapping) const {
506   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
507     return intOrFpAttr.mapValues(newElementType, mapping);
508   llvm_unreachable("unsupported ElementsAttr subtype");
509 }
510 
511 ElementsAttr
512 ElementsAttr::mapValues(Type newElementType,
513                         function_ref<APInt(const APFloat &)> mapping) const {
514   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
515     return intOrFpAttr.mapValues(newElementType, mapping);
516   llvm_unreachable("unsupported ElementsAttr subtype");
517 }
518 
519 /// Method for support type inquiry through isa, cast and dyn_cast.
520 bool ElementsAttr::classof(Attribute attr) {
521   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
522                   OpaqueElementsAttr, SparseElementsAttr>();
523 }
524 
525 /// Returns the 1 dimensional flattened row-major index from the given
526 /// multi-dimensional index.
527 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
528   assert(isValidIndex(index) && "expected valid multi-dimensional index");
529   auto type = getType();
530 
531   // Reduce the provided multidimensional index into a flattended 1D row-major
532   // index.
533   auto rank = type.getRank();
534   auto shape = type.getShape();
535   uint64_t valueIndex = 0;
536   uint64_t dimMultiplier = 1;
537   for (int i = rank - 1; i >= 0; --i) {
538     valueIndex += index[i] * dimMultiplier;
539     dimMultiplier *= shape[i];
540   }
541   return valueIndex;
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // DenseElementsAttr Utilities
546 //===----------------------------------------------------------------------===//
547 
548 /// Get the bitwidth of a dense element type within the buffer.
549 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
550 static size_t getDenseElementStorageWidth(size_t origWidth) {
551   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
552 }
553 static size_t getDenseElementStorageWidth(Type elementType) {
554   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
555 }
556 
/// Set the bit at position `bitPos` of `rawData` to `value`. Bit 0 is the
/// least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}

/// Return the value of the bit at position `bitPos` of `rawData`.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
569 
/// Get start position of actual data in `value`. Actual data is
/// stored in last `bitWidth`/CHAR_BIT bytes in big endian.
///
/// Returns a pointer into `value`'s raw storage at the first byte that holds
/// significant data, for copying ceil(bitWidth / CHAR_BIT) bytes in or out.
/// On little-endian hosts this is the start of the raw storage. On big-endian
/// hosts the pointer is advanced past the unused leading bytes of the first
/// 8 bytes of storage.
/// NOTE(review): the big-endian adjustment only accounts for the first 8
/// bytes, so it appears to assume bitWidth <= 64 — confirm behavior for wider
/// values.
static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
  // getRawData() only exposes a const view; writeAPInt writes through the
  // returned pointer, hence the const_cast.
  char *dataPos =
      const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big)
    dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
  return dataPos;
}

/// Read APInt `value` from appropriate position.
/// Despite the name, this copies the significant ceil(bitWidth / CHAR_BIT)
/// bytes OUT of `value` INTO `outData`; `value` itself is not modified.
static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
}

/// Write `inData` to appropriate position of APInt `value`.
/// Copies ceil(bitWidth / CHAR_BIT) bytes from `inData` into `value`'s raw
/// storage. Presumably `value` must already be at least `bitWidth` bits wide
/// (readBits constructs it that way) — confirm for other callers.
static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
}
592 
593 /// Writes value to the bit position `bitPos` in array `rawData`.
594 static void writeBits(char *rawData, size_t bitPos, APInt value) {
595   size_t bitWidth = value.getBitWidth();
596 
597   // If the bitwidth is 1 we just toggle the specific bit.
598   if (bitWidth == 1)
599     return setBit(rawData, bitPos, value.isOneValue());
600 
601   // Otherwise, the bit position is guaranteed to be byte aligned.
602   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
603   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
604 }
605 
606 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
607 /// `rawData`.
608 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
609   // Handle a boolean bit position.
610   if (bitWidth == 1)
611     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
612 
613   // Otherwise, the bit position must be 8-bit aligned.
614   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
615   APInt result(bitWidth, 0);
616   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
617   return result;
618 }
619 
620 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
621 /// the same element count as 'type'.
622 template <typename Values>
623 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
624   return (values.size() == 1) ||
625          (type.getNumElements() == static_cast<int64_t>(values.size()));
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // DenseElementsAttr Iterators
630 //===----------------------------------------------------------------------===//
631 
632 //===----------------------------------------------------------------------===//
633 // AttributeElementIterator
634 
635 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
636     DenseElementsAttr attr, size_t index)
637     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
638                                       Attribute, Attribute, Attribute>(
639           attr.getAsOpaquePointer(), index) {}
640 
641 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
642   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
643   Type eltTy = owner.getType().getElementType();
644   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
645     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
646   if (eltTy.isa<IndexType>())
647     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
648   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
649     IntElementIterator intIt(owner, index);
650     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
651     return FloatAttr::get(eltTy, *floatIt);
652   }
653   if (owner.isa<DenseStringElementsAttr>()) {
654     ArrayRef<StringRef> vals = owner.getRawStringData();
655     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
656   }
657   llvm_unreachable("unexpected element type");
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // BoolElementIterator
662 
663 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
664     DenseElementsAttr attr, size_t dataIndex)
665     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
666           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
667 
668 bool DenseElementsAttr::BoolElementIterator::operator*() const {
669   return getBit(getData(), getDataIndex());
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // IntElementIterator
674 
675 DenseElementsAttr::IntElementIterator::IntElementIterator(
676     DenseElementsAttr attr, size_t dataIndex)
677     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
678           attr.getRawData().data(), attr.isSplat(), dataIndex),
679       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
680 
681 APInt DenseElementsAttr::IntElementIterator::operator*() const {
682   return readBits(getData(),
683                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
684                   bitWidth);
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // ComplexIntElementIterator
689 
690 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
691     DenseElementsAttr attr, size_t dataIndex)
692     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
693                                       std::complex<APInt>, std::complex<APInt>,
694                                       std::complex<APInt>>(
695           attr.getRawData().data(), attr.isSplat(), dataIndex) {
696   auto complexType = attr.getType().getElementType().cast<ComplexType>();
697   bitWidth = getDenseElementBitWidth(complexType.getElementType());
698 }
699 
700 std::complex<APInt>
701 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
702   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
703   size_t offset = getDataIndex() * storageWidth * 2;
704   return {readBits(getData(), offset, bitWidth),
705           readBits(getData(), offset + storageWidth, bitWidth)};
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // FloatElementIterator
710 
711 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
712     const llvm::fltSemantics &smt, IntElementIterator it)
713     : llvm::mapped_iterator<IntElementIterator,
714                             std::function<APFloat(const APInt &)>>(
715           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
716 
717 //===----------------------------------------------------------------------===//
718 // ComplexFloatElementIterator
719 
720 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
721     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
722     : llvm::mapped_iterator<
723           ComplexIntElementIterator,
724           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
725           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
726             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
727           }) {}
728 
729 //===----------------------------------------------------------------------===//
730 // DenseElementsAttr
731 //===----------------------------------------------------------------------===//
732 
733 /// Method for support type inquiry through isa, cast and dyn_cast.
734 bool DenseElementsAttr::classof(Attribute attr) {
735   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
736 }
737 
738 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
739                                          ArrayRef<Attribute> values) {
740   assert(hasSameElementsOrSplat(type, values));
741 
742   // If the element type is not based on int/float/index, assume it is a string
743   // type.
744   auto eltType = type.getElementType();
745   if (!type.getElementType().isIntOrIndexOrFloat()) {
746     SmallVector<StringRef, 8> stringValues;
747     stringValues.reserve(values.size());
748     for (Attribute attr : values) {
749       assert(attr.isa<StringAttr>() &&
750              "expected string value for non integer/index/float element");
751       stringValues.push_back(attr.cast<StringAttr>().getValue());
752     }
753     return get(type, stringValues);
754   }
755 
756   // Otherwise, get the raw storage width to use for the allocation.
757   size_t bitWidth = getDenseElementBitWidth(eltType);
758   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
759 
760   // Compress the attribute values into a character buffer.
761   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
762                             values.size());
763   APInt intVal;
764   for (unsigned i = 0, e = values.size(); i < e; ++i) {
765     assert(eltType == values[i].getType() &&
766            "expected attribute value to have element type");
767     if (eltType.isa<FloatType>())
768       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
769     else if (eltType.isa<IntegerType>())
770       intVal = values[i].cast<IntegerAttr>().getValue();
771     else
772       llvm_unreachable("unexpected element type");
773 
774     assert(intVal.getBitWidth() == bitWidth &&
775            "expected value to have same bitwidth as element type");
776     writeBits(data.data(), i * storageBitWidth, intVal);
777   }
778   return DenseIntOrFPElementsAttr::getRaw(type, data,
779                                           /*isSplat=*/(values.size() == 1));
780 }
781 
782 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
783                                          ArrayRef<bool> values) {
784   assert(hasSameElementsOrSplat(type, values));
785   assert(type.getElementType().isInteger(1));
786 
787   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
788   for (int i = 0, e = values.size(); i != e; ++i)
789     setBit(buff.data(), i, values[i]);
790   return DenseIntOrFPElementsAttr::getRaw(type, buff,
791                                           /*isSplat=*/(values.size() == 1));
792 }
793 
794 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
795                                          ArrayRef<StringRef> values) {
796   assert(!type.getElementType().isIntOrFloat());
797   return DenseStringElementsAttr::get(type, values);
798 }
799 
800 /// Constructs a dense integer elements attribute from an array of APInt
801 /// values. Each APInt value is expected to have the same bitwidth as the
802 /// element type of 'type'.
803 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
804                                          ArrayRef<APInt> values) {
805   assert(type.getElementType().isIntOrIndex());
806   assert(hasSameElementsOrSplat(type, values));
807   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
808   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
809                                           /*isSplat=*/(values.size() == 1));
810 }
811 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
812                                          ArrayRef<std::complex<APInt>> values) {
813   ComplexType complex = type.getElementType().cast<ComplexType>();
814   assert(complex.getElementType().isa<IntegerType>());
815   assert(hasSameElementsOrSplat(type, values));
816   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
817   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
818                           values.size() * 2);
819   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
820                                           /*isSplat=*/(values.size() == 1));
821 }
822 
823 // Constructs a dense float elements attribute from an array of APFloat
824 // values. Each APFloat value is expected to have the same bitwidth as the
825 // element type of 'type'.
826 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
827                                          ArrayRef<APFloat> values) {
828   assert(type.getElementType().isa<FloatType>());
829   assert(hasSameElementsOrSplat(type, values));
830   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
831   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
832                                           /*isSplat=*/(values.size() == 1));
833 }
834 DenseElementsAttr
835 DenseElementsAttr::get(ShapedType type,
836                        ArrayRef<std::complex<APFloat>> values) {
837   ComplexType complex = type.getElementType().cast<ComplexType>();
838   assert(complex.getElementType().isa<FloatType>());
839   assert(hasSameElementsOrSplat(type, values));
840   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
841                            values.size() * 2);
842   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
843   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
844                                           /*isSplat=*/(values.size() == 1));
845 }
846 
847 /// Construct a dense elements attribute from a raw buffer representing the
848 /// data for this attribute. Users should generally not use this methods as
849 /// the expected buffer format may not be a form the user expects.
850 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
851                                                       ArrayRef<char> rawBuffer,
852                                                       bool isSplatBuffer) {
853   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
854 }
855 
856 /// Returns true if the given buffer is a valid raw buffer for the given type.
857 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
858                                          ArrayRef<char> rawBuffer,
859                                          bool &detectedSplat) {
860   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
861   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
862 
863   // Storage width of 1 is special as it is packed by the bit.
864   if (storageWidth == 1) {
865     // Check for a splat, or a buffer equal to the number of elements.
866     if ((detectedSplat = rawBuffer.size() == 1))
867       return true;
868     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
869   }
870   // All other types are 8-bit aligned.
871   if ((detectedSplat = rawBufferWidth == storageWidth))
872     return true;
873   return rawBufferWidth == (storageWidth * type.getNumElements());
874 }
875 
876 /// Check the information for a C++ data type, check if this type is valid for
877 /// the current attribute. This method is used to verify specific type
878 /// invariants that the templatized 'getValues' method cannot.
879 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
880                               bool isSigned) {
881   // Make sure that the data element size is the same as the type element width.
882   if (getDenseElementBitWidth(type) !=
883       static_cast<size_t>(dataEltSize * CHAR_BIT))
884     return false;
885 
886   // Check that the element type is either float or integer or index.
887   if (!isInt)
888     return type.isa<FloatType>();
889   if (type.isIndex())
890     return true;
891 
892   auto intType = type.dyn_cast<IntegerType>();
893   if (!intType)
894     return false;
895 
896   // Make sure signedness semantics is consistent.
897   if (intType.isSignless())
898     return true;
899   return intType.isSigned() ? isSigned : !isSigned;
900 }
901 
902 /// Defaults down the subclass implementation.
903 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
904                                                    ArrayRef<char> data,
905                                                    int64_t dataEltSize,
906                                                    bool isInt, bool isSigned) {
907   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
908                                                  isSigned);
909 }
910 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
911                                                       ArrayRef<char> data,
912                                                       int64_t dataEltSize,
913                                                       bool isInt,
914                                                       bool isSigned) {
915   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
916                                                     isInt, isSigned);
917 }
918 
919 /// A method used to verify specific type invariants that the templatized 'get'
920 /// method cannot.
921 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
922                                           bool isSigned) const {
923   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
924                              isSigned);
925 }
926 
927 /// Check the information for a C++ data type, check if this type is valid for
928 /// the current attribute.
929 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
930                                        bool isSigned) const {
931   return ::isValidIntOrFloat(
932       getType().getElementType().cast<ComplexType>().getElementType(),
933       dataEltSize / 2, isInt, isSigned);
934 }
935 
936 /// Returns true if this attribute corresponds to a splat, i.e. if all element
937 /// values are the same.
938 bool DenseElementsAttr::isSplat() const {
939   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
940 }
941 
942 /// Return the held element values as a range of Attributes.
943 auto DenseElementsAttr::getAttributeValues() const
944     -> llvm::iterator_range<AttributeElementIterator> {
945   return {attr_value_begin(), attr_value_end()};
946 }
947 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
948   return AttributeElementIterator(*this, 0);
949 }
950 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
951   return AttributeElementIterator(*this, getNumElements());
952 }
953 
954 /// Return the held element values as a range of bool. The element type of
955 /// this attribute must be of integer type of bitwidth 1.
956 auto DenseElementsAttr::getBoolValues() const
957     -> llvm::iterator_range<BoolElementIterator> {
958   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
959   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
960   (void)eltType;
961   return {BoolElementIterator(*this, 0),
962           BoolElementIterator(*this, getNumElements())};
963 }
964 
965 /// Return the held element values as a range of APInts. The element type of
966 /// this attribute must be of integer type.
967 auto DenseElementsAttr::getIntValues() const
968     -> llvm::iterator_range<IntElementIterator> {
969   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
970   return {raw_int_begin(), raw_int_end()};
971 }
972 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
973   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
974   return raw_int_begin();
975 }
976 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
977   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
978   return raw_int_end();
979 }
980 auto DenseElementsAttr::getComplexIntValues() const
981     -> llvm::iterator_range<ComplexIntElementIterator> {
982   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
983   (void)eltTy;
984   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
985   return {ComplexIntElementIterator(*this, 0),
986           ComplexIntElementIterator(*this, getNumElements())};
987 }
988 
989 /// Return the held element values as a range of APFloat. The element type of
990 /// this attribute must be of float type.
991 auto DenseElementsAttr::getFloatValues() const
992     -> llvm::iterator_range<FloatElementIterator> {
993   auto elementType = getType().getElementType().cast<FloatType>();
994   const auto &elementSemantics = elementType.getFloatSemantics();
995   return {FloatElementIterator(elementSemantics, raw_int_begin()),
996           FloatElementIterator(elementSemantics, raw_int_end())};
997 }
998 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
999   return getFloatValues().begin();
1000 }
1001 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1002   return getFloatValues().end();
1003 }
1004 auto DenseElementsAttr::getComplexFloatValues() const
1005     -> llvm::iterator_range<ComplexFloatElementIterator> {
1006   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1007   assert(eltTy.isa<FloatType>() && "expected complex float type");
1008   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1009   return {{semantics, {*this, 0}},
1010           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1011 }
1012 
1013 /// Return the raw storage data held by this attribute.
1014 ArrayRef<char> DenseElementsAttr::getRawData() const {
1015   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1016 }
1017 
1018 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1019   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1020 }
1021 
1022 /// Return a new DenseElementsAttr that has the same data as the current
1023 /// attribute, but has been reshaped to 'newType'. The new type must have the
1024 /// same total number of elements as well as element type.
1025 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1026   ShapedType curType = getType();
1027   if (curType == newType)
1028     return *this;
1029 
1030   (void)curType;
1031   assert(newType.getElementType() == curType.getElementType() &&
1032          "expected the same element type");
1033   assert(newType.getNumElements() == curType.getNumElements() &&
1034          "expected the same number of elements");
1035   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1036 }
1037 
1038 DenseElementsAttr
1039 DenseElementsAttr::mapValues(Type newElementType,
1040                              function_ref<APInt(const APInt &)> mapping) const {
1041   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1042 }
1043 
1044 DenseElementsAttr DenseElementsAttr::mapValues(
1045     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1046   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // DenseStringElementsAttr
1051 //===----------------------------------------------------------------------===//
1052 
1053 DenseStringElementsAttr
1054 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1055   return Base::get(type.getContext(), type, values, (values.size() == 1));
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // DenseIntOrFPElementsAttr
1060 //===----------------------------------------------------------------------===//
1061 
1062 /// Utility method to write a range of APInt values to a buffer.
1063 template <typename APRangeT>
1064 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1065                                 APRangeT &&values) {
1066   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1067   size_t offset = 0;
1068   for (auto it = values.begin(), e = values.end(); it != e;
1069        ++it, offset += storageWidth) {
1070     assert((*it).getBitWidth() <= storageWidth);
1071     writeBits(data.data(), offset, *it);
1072   }
1073 }
1074 
1075 /// Constructs a dense elements attribute from an array of raw APFloat values.
1076 /// Each APFloat value is expected to have the same bitwidth as the element
1077 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1078 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1079                                                    size_t storageWidth,
1080                                                    ArrayRef<APFloat> values,
1081                                                    bool isSplat) {
1082   std::vector<char> data;
1083   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1084   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1085   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1086 }
1087 
1088 /// Constructs a dense elements attribute from an array of raw APInt values.
1089 /// Each APInt value is expected to have the same bitwidth as the element type
1090 /// of 'type'.
1091 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1092                                                    size_t storageWidth,
1093                                                    ArrayRef<APInt> values,
1094                                                    bool isSplat) {
1095   std::vector<char> data;
1096   writeAPIntsToBuffer(storageWidth, data, values);
1097   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1098 }
1099 
1100 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1101                                                    ArrayRef<char> data,
1102                                                    bool isSplat) {
1103   assert((type.isa<RankedTensorType, VectorType>()) &&
1104          "type must be ranked tensor or vector");
1105   assert(type.hasStaticShape() && "type must have static shape");
1106   return Base::get(type.getContext(), type, data, isSplat);
1107 }
1108 
1109 /// Overload of the raw 'get' method that asserts that the given type is of
1110 /// complex type. This method is used to verify type invariants that the
1111 /// templatized 'get' method cannot.
1112 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1113                                                           ArrayRef<char> data,
1114                                                           int64_t dataEltSize,
1115                                                           bool isInt,
1116                                                           bool isSigned) {
1117   assert(::isValidIntOrFloat(
1118       type.getElementType().cast<ComplexType>().getElementType(),
1119       dataEltSize / 2, isInt, isSigned));
1120 
1121   int64_t numElements = data.size() / dataEltSize;
1122   assert(numElements == 1 || numElements == type.getNumElements());
1123   return getRaw(type, data, /*isSplat=*/numElements == 1);
1124 }
1125 
1126 /// Overload of the 'getRaw' method that asserts that the given type is of
1127 /// integer type. This method is used to verify type invariants that the
1128 /// templatized 'get' method cannot.
1129 DenseElementsAttr
1130 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1131                                            int64_t dataEltSize, bool isInt,
1132                                            bool isSigned) {
1133   assert(
1134       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1135 
1136   int64_t numElements = data.size() / dataEltSize;
1137   assert(numElements == 1 || numElements == type.getNumElements());
1138   return getRaw(type, data, /*isSplat=*/numElements == 1);
1139 }
1140 
1141 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1142     const char *inRawData, char *outRawData, size_t elementBitWidth,
1143     size_t numElements) {
1144   using llvm::support::ulittle16_t;
1145   using llvm::support::ulittle32_t;
1146   using llvm::support::ulittle64_t;
1147 
1148   assert(llvm::support::endian::system_endianness() == // NOLINT
1149          llvm::support::endianness::big);              // NOLINT
1150   // NOLINT to avoid warning message about replacing by static_assert()
1151 
1152   // Following std::copy_n always converts endianness on BE machine.
1153   switch (elementBitWidth) {
1154   case 16: {
1155     const ulittle16_t *inRawDataPos =
1156         reinterpret_cast<const ulittle16_t *>(inRawData);
1157     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1158     std::copy_n(inRawDataPos, numElements, outDataPos);
1159     break;
1160   }
1161   case 32: {
1162     const ulittle32_t *inRawDataPos =
1163         reinterpret_cast<const ulittle32_t *>(inRawData);
1164     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1165     std::copy_n(inRawDataPos, numElements, outDataPos);
1166     break;
1167   }
1168   case 64: {
1169     const ulittle64_t *inRawDataPos =
1170         reinterpret_cast<const ulittle64_t *>(inRawData);
1171     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1172     std::copy_n(inRawDataPos, numElements, outDataPos);
1173     break;
1174   }
1175   default: {
1176     size_t nBytes = elementBitWidth / CHAR_BIT;
1177     for (size_t i = 0; i < nBytes; i++)
1178       std::copy_n(inRawData + (nBytes - 1 - i), numElements, outRawData + i);
1179     break;
1180   }
1181   }
1182 }
1183 
1184 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1185     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1186     ShapedType type) {
1187   size_t numElements = type.getNumElements();
1188   Type elementType = type.getElementType();
1189   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1190     elementType = complexTy.getElementType();
1191     numElements = numElements * 2;
1192   }
1193   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1194   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1195          inRawData.size() <= outRawData.size());
1196   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1197                                   elementBitWidth, numElements);
1198 }
1199 
1200 //===----------------------------------------------------------------------===//
1201 // DenseFPElementsAttr
1202 //===----------------------------------------------------------------------===//
1203 
1204 template <typename Fn, typename Attr>
1205 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1206                                 Type newElementType,
1207                                 llvm::SmallVectorImpl<char> &data) {
1208   size_t bitWidth = getDenseElementBitWidth(newElementType);
1209   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1210 
1211   ShapedType newArrayType;
1212   if (inType.isa<RankedTensorType>())
1213     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1214   else if (inType.isa<UnrankedTensorType>())
1215     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1216   else if (inType.isa<VectorType>())
1217     newArrayType = VectorType::get(inType.getShape(), newElementType);
1218   else
1219     assert(newArrayType && "Unhandled tensor type");
1220 
1221   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1222   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1223 
1224   // Functor used to process a single element value of the attribute.
1225   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1226     auto newInt = mapping(value);
1227     assert(newInt.getBitWidth() == bitWidth);
1228     writeBits(data.data(), index * storageBitWidth, newInt);
1229   };
1230 
1231   // Check for the splat case.
1232   if (attr.isSplat()) {
1233     processElt(*attr.begin(), /*index=*/0);
1234     return newArrayType;
1235   }
1236 
1237   // Otherwise, process all of the element values.
1238   uint64_t elementIdx = 0;
1239   for (auto value : attr)
1240     processElt(value, elementIdx++);
1241   return newArrayType;
1242 }
1243 
1244 DenseElementsAttr DenseFPElementsAttr::mapValues(
1245     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1246   llvm::SmallVector<char, 8> elementData;
1247   auto newArrayType =
1248       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1249 
1250   return getRaw(newArrayType, elementData, isSplat());
1251 }
1252 
1253 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1254 bool DenseFPElementsAttr::classof(Attribute attr) {
1255   return attr.isa<DenseElementsAttr>() &&
1256          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // DenseIntElementsAttr
1261 //===----------------------------------------------------------------------===//
1262 
1263 DenseElementsAttr DenseIntElementsAttr::mapValues(
1264     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1265   llvm::SmallVector<char, 8> elementData;
1266   auto newArrayType =
1267       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1268 
1269   return getRaw(newArrayType, elementData, isSplat());
1270 }
1271 
1272 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1273 bool DenseIntElementsAttr::classof(Attribute attr) {
1274   return attr.isa<DenseElementsAttr>() &&
1275          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1276 }
1277 
1278 //===----------------------------------------------------------------------===//
1279 // OpaqueElementsAttr
1280 //===----------------------------------------------------------------------===//
1281 
1282 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1283                                            StringRef bytes) {
1284   assert(TensorType::isValidElementType(type.getElementType()) &&
1285          "Input element type should be a valid tensor element type");
1286   return Base::get(type.getContext(), type, dialect, bytes);
1287 }
1288 
1289 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1290 
1291 /// Return the value at the given index. If index does not refer to a valid
1292 /// element, then a null attribute is returned.
1293 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1294   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1295   return Attribute();
1296 }
1297 
1298 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1299 
1300 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1301   auto *d = getDialect();
1302   if (!d)
1303     return true;
1304   auto *interface =
1305       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1306   if (!interface)
1307     return true;
1308   return failed(interface->decode(*this, result));
1309 }
1310 
1311 //===----------------------------------------------------------------------===//
1312 // SparseElementsAttr
1313 //===----------------------------------------------------------------------===//
1314 
1315 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1316                                            DenseElementsAttr indices,
1317                                            DenseElementsAttr values) {
1318   assert(indices.getType().getElementType().isInteger(64) &&
1319          "expected sparse indices to be 64-bit integer values");
1320   assert((type.isa<RankedTensorType, VectorType>()) &&
1321          "type must be ranked tensor or vector");
1322   assert(type.hasStaticShape() && "type must have static shape");
1323   return Base::get(type.getContext(), type,
1324                    indices.cast<DenseIntElementsAttr>(), values);
1325 }
1326 
1327 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1328   return getImpl()->indices;
1329 }
1330 
1331 DenseElementsAttr SparseElementsAttr::getValues() const {
1332   return getImpl()->values;
1333 }
1334 
1335 /// Return the value of the element at the given index.
1336 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1337   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1338   auto type = getType();
1339 
1340   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1341   // as a 1-D index array.
1342   auto sparseIndices = getIndices();
1343   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1344 
1345   // Check to see if the indices are a splat.
1346   if (sparseIndices.isSplat()) {
1347     // If the index is also not a splat of the index value, we know that the
1348     // value is zero.
1349     auto splatIndex = *sparseIndexValues.begin();
1350     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1351       return getZeroAttr();
1352 
1353     // If the indices are a splat, we also expect the values to be a splat.
1354     assert(getValues().isSplat() && "expected splat values");
1355     return getValues().getSplatValue();
1356   }
1357 
1358   // Build a mapping between known indices and the offset of the stored element.
1359   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1360   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1361   size_t rank = type.getRank();
1362   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1363     mappedIndices.try_emplace(
1364         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1365 
1366   // Look for the provided index key within the mapped indices. If the provided
1367   // index is not found, then return a zero attribute.
1368   auto it = mappedIndices.find(index);
1369   if (it == mappedIndices.end())
1370     return getZeroAttr();
1371 
1372   // Otherwise, return the held sparse value element.
1373   return getValues().getValue(it->second);
1374 }
1375 
1376 /// Get a zero APFloat for the given sparse attribute.
1377 APFloat SparseElementsAttr::getZeroAPFloat() const {
1378   auto eltType = getType().getElementType().cast<FloatType>();
1379   return APFloat(eltType.getFloatSemantics());
1380 }
1381 
1382 /// Get a zero APInt for the given sparse attribute.
1383 APInt SparseElementsAttr::getZeroAPInt() const {
1384   auto eltType = getType().getElementType().cast<IntegerType>();
1385   return APInt::getNullValue(eltType.getWidth());
1386 }
1387 
1388 /// Get a zero attribute for the given attribute type.
1389 Attribute SparseElementsAttr::getZeroAttr() const {
1390   auto eltType = getType().getElementType();
1391 
1392   // Handle floating point elements.
1393   if (eltType.isa<FloatType>())
1394     return FloatAttr::get(eltType, 0);
1395 
1396   // Otherwise, this is an integer.
1397   // TODO: Handle StringAttr here.
1398   return IntegerAttr::get(eltType, 0);
1399 }
1400 
1401 /// Flatten, and return, all of the sparse indices in this attribute in
1402 /// row-major order.
1403 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1404   std::vector<ptrdiff_t> flatSparseIndices;
1405 
1406   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1407   // as a 1-D index array.
1408   auto sparseIndices = getIndices();
1409   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1410   if (sparseIndices.isSplat()) {
1411     SmallVector<uint64_t, 8> indices(getType().getRank(),
1412                                      *sparseIndexValues.begin());
1413     flatSparseIndices.push_back(getFlattenedIndex(indices));
1414     return flatSparseIndices;
1415   }
1416 
1417   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1418   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1419   size_t rank = getType().getRank();
1420   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1421     flatSparseIndices.push_back(getFlattenedIndex(
1422         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1423   return flatSparseIndices;
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // MutableDictionaryAttr
1428 //===----------------------------------------------------------------------===//
1429 
1430 MutableDictionaryAttr::MutableDictionaryAttr(
1431     ArrayRef<NamedAttribute> attributes) {
1432   setAttrs(attributes);
1433 }
1434 
1435 /// Return the underlying dictionary attribute.
1436 DictionaryAttr
1437 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1438   // Construct empty DictionaryAttr if needed.
1439   if (!attrs)
1440     return DictionaryAttr::get({}, context);
1441   return attrs;
1442 }
1443 
1444 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1445   return attrs ? attrs.getValue() : llvm::None;
1446 }
1447 
1448 /// Replace the held attributes with ones provided in 'newAttrs'.
1449 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1450   // Don't create an attribute list if there are no attributes.
1451   if (attributes.empty())
1452     attrs = nullptr;
1453   else
1454     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1455 }
1456 
1457 /// Return the specified attribute if present, null otherwise.
1458 Attribute MutableDictionaryAttr::get(StringRef name) const {
1459   return attrs ? attrs.get(name) : nullptr;
1460 }
1461 
1462 /// Return the specified attribute if present, null otherwise.
1463 Attribute MutableDictionaryAttr::get(Identifier name) const {
1464   return attrs ? attrs.get(name) : nullptr;
1465 }
1466 
1467 /// Return the specified named attribute if present, None otherwise.
1468 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1469   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1470 }
1471 Optional<NamedAttribute>
1472 MutableDictionaryAttr::getNamed(Identifier name) const {
1473   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1474 }
1475 
1476 /// If the an attribute exists with the specified name, change it to the new
1477 /// value.  Otherwise, add a new attribute with the specified name/value.
1478 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1479   assert(value && "attributes may never be null");
1480 
1481   // Look for an existing value for the given name, and set it in-place.
1482   ArrayRef<NamedAttribute> values = getAttrs();
1483   const auto *it = llvm::find_if(
1484       values, [name](NamedAttribute attr) { return attr.first == name; });
1485   if (it != values.end()) {
1486     // Bail out early if the value is the same as what we already have.
1487     if (it->second == value)
1488       return;
1489 
1490     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1491     newAttrs[it - values.begin()].second = value;
1492     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1493     return;
1494   }
1495 
1496   // Otherwise, insert the new attribute into its sorted position.
1497   it = llvm::lower_bound(values, name);
1498   SmallVector<NamedAttribute, 8> newAttrs;
1499   newAttrs.reserve(values.size() + 1);
1500   newAttrs.append(values.begin(), it);
1501   newAttrs.push_back({name, value});
1502   newAttrs.append(it, values.end());
1503   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1504 }
1505 
1506 /// Remove the attribute with the specified name if it exists.  The return
1507 /// value indicates whether the attribute was present or not.
1508 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1509   auto origAttrs = getAttrs();
1510   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1511     if (origAttrs[i].first == name) {
1512       // Handle the simple case of removing the only attribute in the list.
1513       if (e == 1) {
1514         attrs = nullptr;
1515         return RemoveResult::Removed;
1516       }
1517 
1518       SmallVector<NamedAttribute, 8> newAttrs;
1519       newAttrs.reserve(origAttrs.size() - 1);
1520       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1521       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1522       attrs = DictionaryAttr::getWithSorted(newAttrs,
1523                                             newAttrs[0].second.getContext());
1524       return RemoveResult::Removed;
1525     }
1526   }
1527   return RemoveResult::NotFound;
1528 }
1529 
1530 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
1531   return strcmp(lhs.first.data(), rhs.first.data()) < 0;
1532 }
/// Return true if the name of 'lhs' orders before the string 'rhs'. Like the
/// NamedAttribute/NamedAttribute overload, this compares names as C strings.
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
  // This is correct even when lhs.first.data()[rhs.size()] is not a zero
  // string terminator, because we only care about a less than comparison:
  // if the first rhs.size() characters match, lhs is equal to or longer than
  // rhs and so is not less than it. If lhs is shorter, strncmp stops at its
  // terminator, which compares below any character of rhs.
  // This can't use memcmp, because it doesn't guarantee that it will stop
  // reading both buffers if one is shorter than the other, even if there is
  // a difference.
  return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
}
1541