//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AttributeStorage
26 //===----------------------------------------------------------------------===//
27 
28 AttributeStorage::AttributeStorage(Type type)
29     : type(type.getAsOpaquePointer()) {}
30 AttributeStorage::AttributeStorage() : type(nullptr) {}
31 
32 Type AttributeStorage::getType() const {
33   return Type::getFromOpaquePointer(type);
34 }
35 void AttributeStorage::setType(Type newType) {
36   type = newType.getAsOpaquePointer();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Attribute
41 //===----------------------------------------------------------------------===//
42 
43 /// Return the type of this attribute.
44 Type Attribute::getType() const { return impl->getType(); }
45 
46 /// Return the context this attribute belongs to.
47 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
48 
49 /// Get the dialect this attribute is registered to.
50 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
51 
52 //===----------------------------------------------------------------------===//
53 // AffineMapAttr
54 //===----------------------------------------------------------------------===//
55 
56 AffineMapAttr AffineMapAttr::get(AffineMap value) {
57   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
58 }
59 
60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
61 
62 //===----------------------------------------------------------------------===//
63 // ArrayAttr
64 //===----------------------------------------------------------------------===//
65 
66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
67   return Base::get(context, StandardAttributes::Array, value);
68 }
69 
70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
71 
72 Attribute ArrayAttr::operator[](unsigned idx) const {
73   assert(idx < size() && "index out of bounds");
74   return getValue()[idx];
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // BoolAttr
79 //===----------------------------------------------------------------------===//
80 
81 bool BoolAttr::getValue() const { return getImpl()->value; }
82 
83 //===----------------------------------------------------------------------===//
84 // DictionaryAttr
85 //===----------------------------------------------------------------------===//
86 
/// Sort a list of named attributes by name.
///
/// If `inPlace` is true, `storage` is both the source and the destination. The
/// caller passes the same container as `value`, as `sortInPlace` does, and the
/// elements of `storage` are reordered directly. If `inPlace` is false, `value`
/// is the source. For any non-empty input, `storage` then receives a copy of it
/// in sorted order.
///
/// Returns true if the source was NOT already sorted, i.e. when the caller has
/// to use the contents of `storage` instead of the original `value` order.
/// Debug builds also assert that no two entries share a name (checked only
/// for two or more entries).
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    // Two elements: a single comparison decides the order.
    assert(value[0].first != value[1].first &&
           "DictionaryAttr element names must be unique");
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      // `storage` aliases the source here, so swapping it reorders `value`.
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    // Copy first so that `storage` always holds the data being sorted.
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already.
    bool isSorted = llvm::is_sorted(value);
    if (!isSorted) {
      // If not, do a general sort.
      llvm::array_pod_sort(storage.begin(), storage.end());
      // Re-point `value` at the sorted data: the uniqueness check below only
      // compares neighbours, so it must run over the sorted order.
      value = storage;
    }

    // Ensure that the attribute elements are unique.
    assert(std::adjacent_find(value.begin(), value.end(),
                              [](NamedAttribute l, NamedAttribute r) {
                                return l.first == r.first;
                              }) == value.end() &&
           "DictionaryAttr element names must be unique");
    return !isSorted;
  }
  return false;
}
139 
140 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
141                           SmallVectorImpl<NamedAttribute> &storage) {
142   return dictionaryAttrSort</*inPlace=*/false>(value, storage);
143 }
144 
145 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
146   return dictionaryAttrSort</*inPlace=*/true>(array, array);
147 }
148 
149 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
150                                    MLIRContext *context) {
151   if (value.empty())
152     return DictionaryAttr::getEmpty(context);
153   assert(llvm::all_of(value,
154                       [](const NamedAttribute &attr) { return attr.second; }) &&
155          "value cannot have null entries");
156 
157   // We need to sort the element list to canonicalize it.
158   SmallVector<NamedAttribute, 8> storage;
159   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
160     value = storage;
161 
162   return Base::get(context, StandardAttributes::Dictionary, value);
163 }
164 /// Construct a dictionary with an array of values that is known to already be
165 /// sorted by name and uniqued.
166 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
167                                              MLIRContext *context) {
168   if (value.empty())
169     return DictionaryAttr::getEmpty(context);
170   // Ensure that the attribute elements are unique and sorted.
171   assert(llvm::is_sorted(value,
172                          [](NamedAttribute l, NamedAttribute r) {
173                            return l.first.strref() < r.first.strref();
174                          }) &&
175          "expected attribute values to be sorted");
176   assert(std::adjacent_find(value.begin(), value.end(),
177                             [](NamedAttribute l, NamedAttribute r) {
178                               return l.first == r.first;
179                             }) == value.end() &&
180          "DictionaryAttr element names must be unique");
181   return Base::get(context, StandardAttributes::Dictionary, value);
182 }
183 
184 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
185   return getImpl()->getElements();
186 }
187 
188 /// Return the specified attribute if present, null otherwise.
189 Attribute DictionaryAttr::get(StringRef name) const {
190   Optional<NamedAttribute> attr = getNamed(name);
191   return attr ? attr->second : nullptr;
192 }
193 Attribute DictionaryAttr::get(Identifier name) const {
194   Optional<NamedAttribute> attr = getNamed(name);
195   return attr ? attr->second : nullptr;
196 }
197 
198 /// Return the specified named attribute if present, None otherwise.
199 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
200   ArrayRef<NamedAttribute> values = getValue();
201   const auto *it = llvm::lower_bound(values, name);
202   return it != values.end() && it->first == name ? *it
203                                                  : Optional<NamedAttribute>();
204 }
205 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
206   for (auto elt : getValue())
207     if (elt.first == name)
208       return elt;
209   return llvm::None;
210 }
211 
212 DictionaryAttr::iterator DictionaryAttr::begin() const {
213   return getValue().begin();
214 }
215 DictionaryAttr::iterator DictionaryAttr::end() const {
216   return getValue().end();
217 }
218 size_t DictionaryAttr::size() const { return getValue().size(); }
219 
220 //===----------------------------------------------------------------------===//
221 // FloatAttr
222 //===----------------------------------------------------------------------===//
223 
224 FloatAttr FloatAttr::get(Type type, double value) {
225   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
226 }
227 
228 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
229   return Base::getChecked(loc, StandardAttributes::Float, type, value);
230 }
231 
232 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
233   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
234 }
235 
236 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
237   return Base::getChecked(loc, StandardAttributes::Float, type, value);
238 }
239 
240 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
241 
242 double FloatAttr::getValueAsDouble() const {
243   return getValueAsDouble(getValue());
244 }
245 double FloatAttr::getValueAsDouble(APFloat value) {
246   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
247     bool losesInfo = false;
248     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
249                   &losesInfo);
250   }
251   return value.convertToDouble();
252 }
253 
254 /// Verify construction invariants.
255 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
256   if (!type.isa<FloatType>())
257     return emitError(loc, "expected floating point type");
258   return success();
259 }
260 
261 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
262                                                       double value) {
263   return verifyFloatTypeInvariants(loc, type);
264 }
265 
266 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
267                                                       const APFloat &value) {
268   // Verify that the type is correct.
269   if (failed(verifyFloatTypeInvariants(loc, type)))
270     return failure();
271 
272   // Verify that the type semantics match that of the value.
273   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
274     return emitError(
275         loc, "FloatAttr type doesn't match the type implied by its value");
276   }
277   return success();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // SymbolRefAttr
282 //===----------------------------------------------------------------------===//
283 
284 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
285   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
286       .cast<FlatSymbolRefAttr>();
287 }
288 
289 SymbolRefAttr SymbolRefAttr::get(StringRef value,
290                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
291                                  MLIRContext *ctx) {
292   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
293 }
294 
295 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
296 
297 StringRef SymbolRefAttr::getLeafReference() const {
298   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
299   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
300 }
301 
302 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
303   return getImpl()->getNestedRefs();
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // IntegerAttr
308 //===----------------------------------------------------------------------===//
309 
310 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
311   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
312 }
313 
314 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
315   // This uses 64 bit APInts by default for index type.
316   if (type.isIndex())
317     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
318 
319   auto intType = type.cast<IntegerType>();
320   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
321 }
322 
323 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
324 
325 int64_t IntegerAttr::getInt() const {
326   assert((getImpl()->getType().isIndex() ||
327           getImpl()->getType().isSignlessInteger()) &&
328          "must be signless integer");
329   return getValue().getSExtValue();
330 }
331 
332 int64_t IntegerAttr::getSInt() const {
333   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
334   return getValue().getSExtValue();
335 }
336 
337 uint64_t IntegerAttr::getUInt() const {
338   assert(getImpl()->getType().isUnsignedInteger() &&
339          "must be unsigned integer");
340   return getValue().getZExtValue();
341 }
342 
343 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
344   if (type.isa<IntegerType>() || type.isa<IndexType>())
345     return success();
346   return emitError(loc, "expected integer or index type");
347 }
348 
349 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
350                                                         int64_t value) {
351   return verifyIntegerTypeInvariants(loc, type);
352 }
353 
354 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
355                                                         const APInt &value) {
356   if (failed(verifyIntegerTypeInvariants(loc, type)))
357     return failure();
358   if (auto integerType = type.dyn_cast<IntegerType>())
359     if (integerType.getWidth() != value.getBitWidth())
360       return emitError(loc, "integer type bit width (")
361              << integerType.getWidth() << ") doesn't match value bit width ("
362              << value.getBitWidth() << ")";
363   return success();
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // IntegerSetAttr
368 //===----------------------------------------------------------------------===//
369 
370 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
371   return Base::get(value.getConstraint(0).getContext(),
372                    StandardAttributes::IntegerSet, value);
373 }
374 
375 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
376 
377 //===----------------------------------------------------------------------===//
378 // OpaqueAttr
379 //===----------------------------------------------------------------------===//
380 
381 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
382                            MLIRContext *context) {
383   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
384                    type);
385 }
386 
387 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
388                                   Type type, Location location) {
389   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
390                           attrData, type);
391 }
392 
393 /// Returns the dialect namespace of the opaque attribute.
394 Identifier OpaqueAttr::getDialectNamespace() const {
395   return getImpl()->dialectNamespace;
396 }
397 
398 /// Returns the raw attribute data of the opaque attribute.
399 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
400 
401 /// Verify the construction of an opaque attribute.
402 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
403                                                        Identifier dialect,
404                                                        StringRef attrData,
405                                                        Type type) {
406   if (!Dialect::isValidNamespace(dialect.strref()))
407     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
408   return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // StringAttr
413 //===----------------------------------------------------------------------===//
414 
415 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
416   return get(bytes, NoneType::get(context));
417 }
418 
419 /// Get an instance of a StringAttr with the given string and Type.
420 StringAttr StringAttr::get(StringRef bytes, Type type) {
421   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
422 }
423 
424 StringRef StringAttr::getValue() const { return getImpl()->value; }
425 
426 //===----------------------------------------------------------------------===//
427 // TypeAttr
428 //===----------------------------------------------------------------------===//
429 
430 TypeAttr TypeAttr::get(Type value) {
431   return Base::get(value.getContext(), StandardAttributes::Type, value);
432 }
433 
434 Type TypeAttr::getValue() const { return getImpl()->value; }
435 
436 //===----------------------------------------------------------------------===//
437 // ElementsAttr
438 //===----------------------------------------------------------------------===//
439 
440 ShapedType ElementsAttr::getType() const {
441   return Attribute::getType().cast<ShapedType>();
442 }
443 
444 /// Returns the number of elements held by this attribute.
445 int64_t ElementsAttr::getNumElements() const {
446   return getType().getNumElements();
447 }
448 
449 /// Return the value at the given index. If index does not refer to a valid
450 /// element, then a null attribute is returned.
451 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
452   switch (getKind()) {
453   case StandardAttributes::DenseIntOrFPElements:
454     return cast<DenseElementsAttr>().getValue(index);
455   case StandardAttributes::OpaqueElements:
456     return cast<OpaqueElementsAttr>().getValue(index);
457   case StandardAttributes::SparseElements:
458     return cast<SparseElementsAttr>().getValue(index);
459   default:
460     llvm_unreachable("unknown ElementsAttr kind");
461   }
462 }
463 
464 /// Return if the given 'index' refers to a valid element in this attribute.
465 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
466   auto type = getType();
467 
468   // Verify that the rank of the indices matches the held type.
469   auto rank = type.getRank();
470   if (rank != static_cast<int64_t>(index.size()))
471     return false;
472 
473   // Verify that all of the indices are within the shape dimensions.
474   auto shape = type.getShape();
475   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
476     return static_cast<int64_t>(index[i]) < shape[i];
477   });
478 }
479 
480 ElementsAttr
481 ElementsAttr::mapValues(Type newElementType,
482                         function_ref<APInt(const APInt &)> mapping) const {
483   switch (getKind()) {
484   case StandardAttributes::DenseIntOrFPElements:
485     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
486   default:
487     llvm_unreachable("unsupported ElementsAttr subtype");
488   }
489 }
490 
491 ElementsAttr
492 ElementsAttr::mapValues(Type newElementType,
493                         function_ref<APInt(const APFloat &)> mapping) const {
494   switch (getKind()) {
495   case StandardAttributes::DenseIntOrFPElements:
496     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
497   default:
498     llvm_unreachable("unsupported ElementsAttr subtype");
499   }
500 }
501 
502 /// Returns the 1 dimensional flattened row-major index from the given
503 /// multi-dimensional index.
504 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
505   assert(isValidIndex(index) && "expected valid multi-dimensional index");
506   auto type = getType();
507 
508   // Reduce the provided multidimensional index into a flattended 1D row-major
509   // index.
510   auto rank = type.getRank();
511   auto shape = type.getShape();
512   uint64_t valueIndex = 0;
513   uint64_t dimMultiplier = 1;
514   for (int i = rank - 1; i >= 0; --i) {
515     valueIndex += index[i] * dimMultiplier;
516     dimMultiplier *= shape[i];
517   }
518   return valueIndex;
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // DenseElementAttr Utilities
523 //===----------------------------------------------------------------------===//
524 
525 /// Get the bitwidth of a dense element type within the buffer.
526 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
527 static size_t getDenseElementStorageWidth(size_t origWidth) {
528   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
529 }
530 static size_t getDenseElementStorageWidth(Type elementType) {
531   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
532 }
533 
/// Set bit `bitPos` of `rawData` to `value`. Bits are numbered from the least
/// significant bit of byte 0 upward.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}

/// Return bit `bitPos` of `rawData`, using the same numbering as `setBit`.
static bool getBit(const char *rawData, size_t bitPos) {
  const auto byte = static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
546 
/// Return a pointer to the first of the `divideCeil(bitWidth, CHAR_BIT)` bytes
/// of `value`'s raw word storage that hold its data. On little-endian hosts
/// this is simply the start of the storage. On big-endian hosts the low-order
/// bytes sit at the end of the first 64-bit word, so the pointer is advanced by
/// `8 - divideCeil(bitWidth, CHAR_BIT)`.
/// NOTE(review): for bitWidth > 64 that offset is negative and points before
/// the storage; multi-word values look unsupported on big-endian hosts —
/// confirm before relying on wide integers there.
static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
  // const is cast away so the writeAPInt caller can store through this pointer.
  char *dataPos =
      const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big)
    dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
  return dataPos;
}

/// Copy the `divideCeil(bitWidth, CHAR_BIT)` data bytes of `value` into
/// `outData`, in host byte order. (Reads FROM the APInt.)
static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
}

/// Copy `divideCeil(bitWidth, CHAR_BIT)` bytes from `inData` into the data
/// bytes of `value`, in host byte order. `value` is modified in place.
/// (Writes INTO the APInt.)
static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
}
569 
570 /// Writes value to the bit position `bitPos` in array `rawData`.
571 static void writeBits(char *rawData, size_t bitPos, APInt value) {
572   size_t bitWidth = value.getBitWidth();
573 
574   // If the bitwidth is 1 we just toggle the specific bit.
575   if (bitWidth == 1)
576     return setBit(rawData, bitPos, value.isOneValue());
577 
578   // Otherwise, the bit position is guaranteed to be byte aligned.
579   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
580   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
581 }
582 
583 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
584 /// `rawData`.
585 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
586   // Handle a boolean bit position.
587   if (bitWidth == 1)
588     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
589 
590   // Otherwise, the bit position must be 8-bit aligned.
591   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
592   APInt result(bitWidth, 0);
593   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
594   return result;
595 }
596 
597 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
598 /// same element count as 'type'.
599 template <typename Values>
600 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
601   return (values.size() == 1) ||
602          (type.getNumElements() == static_cast<int64_t>(values.size()));
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // DenseElementAttr Iterators
607 //===----------------------------------------------------------------------===//
608 
609 //===----------------------------------------------------------------------===//
610 // AttributeElementIterator
611 
612 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
613     DenseElementsAttr attr, size_t index)
614     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
615                                       Attribute, Attribute, Attribute>(
616           attr.getAsOpaquePointer(), index) {}
617 
618 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
619   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
620   Type eltTy = owner.getType().getElementType();
621   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
622     if (intEltTy.getWidth() == 1)
623       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
624                            owner.getContext());
625     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
626   }
627   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
628     IntElementIterator intIt(owner, index);
629     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
630     return FloatAttr::get(eltTy, *floatIt);
631   }
632   if (owner.isa<DenseStringElementsAttr>()) {
633     ArrayRef<StringRef> vals = owner.getRawStringData();
634     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
635   }
636   llvm_unreachable("unexpected element type");
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // BoolElementIterator
641 
642 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
643     DenseElementsAttr attr, size_t dataIndex)
644     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
645           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
646 
647 bool DenseElementsAttr::BoolElementIterator::operator*() const {
648   return getBit(getData(), getDataIndex());
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // IntElementIterator
653 
654 DenseElementsAttr::IntElementIterator::IntElementIterator(
655     DenseElementsAttr attr, size_t dataIndex)
656     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
657           attr.getRawData().data(), attr.isSplat(), dataIndex),
658       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
659 
660 APInt DenseElementsAttr::IntElementIterator::operator*() const {
661   return readBits(getData(),
662                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
663                   bitWidth);
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // ComplexIntElementIterator
668 
669 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
670     DenseElementsAttr attr, size_t dataIndex)
671     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
672                                       std::complex<APInt>, std::complex<APInt>,
673                                       std::complex<APInt>>(
674           attr.getRawData().data(), attr.isSplat(), dataIndex) {
675   auto complexType = attr.getType().getElementType().cast<ComplexType>();
676   bitWidth = getDenseElementBitWidth(complexType.getElementType());
677 }
678 
679 std::complex<APInt>
680 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
681   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
682   size_t offset = getDataIndex() * storageWidth * 2;
683   return {readBits(getData(), offset, bitWidth),
684           readBits(getData(), offset + storageWidth, bitWidth)};
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // FloatElementIterator
689 
690 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
691     const llvm::fltSemantics &smt, IntElementIterator it)
692     : llvm::mapped_iterator<IntElementIterator,
693                             std::function<APFloat(const APInt &)>>(
694           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
695 
696 //===----------------------------------------------------------------------===//
697 // ComplexFloatElementIterator
698 
699 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
700     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
701     : llvm::mapped_iterator<
702           ComplexIntElementIterator,
703           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
704           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
705             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
706           }) {}
707 
708 //===----------------------------------------------------------------------===//
709 // DenseElementsAttr
710 //===----------------------------------------------------------------------===//
711 
712 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
713                                          ArrayRef<Attribute> values) {
714   assert(hasSameElementsOrSplat(type, values));
715 
716   // If the element type is not based on int/float/index, assume it is a string
717   // type.
718   auto eltType = type.getElementType();
719   if (!type.getElementType().isIntOrIndexOrFloat()) {
720     SmallVector<StringRef, 8> stringValues;
721     stringValues.reserve(values.size());
722     for (Attribute attr : values) {
723       assert(attr.isa<StringAttr>() &&
724              "expected string value for non integer/index/float element");
725       stringValues.push_back(attr.cast<StringAttr>().getValue());
726     }
727     return get(type, stringValues);
728   }
729 
730   // Otherwise, get the raw storage width to use for the allocation.
731   size_t bitWidth = getDenseElementBitWidth(eltType);
732   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
733 
734   // Compress the attribute values into a character buffer.
735   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
736                             values.size());
737   APInt intVal;
738   for (unsigned i = 0, e = values.size(); i < e; ++i) {
739     assert(eltType == values[i].getType() &&
740            "expected attribute value to have element type");
741 
742     switch (eltType.getKind()) {
743     case StandardTypes::BF16:
744     case StandardTypes::F16:
745     case StandardTypes::F32:
746     case StandardTypes::F64:
747       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
748       break;
749     case StandardTypes::Integer:
750     case StandardTypes::Index:
751       intVal = values[i].isa<BoolAttr>()
752                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
753                    : values[i].cast<IntegerAttr>().getValue();
754       break;
755     default:
756       llvm_unreachable("unexpected element type");
757     }
758     assert(intVal.getBitWidth() == bitWidth &&
759            "expected value to have same bitwidth as element type");
760     writeBits(data.data(), i * storageBitWidth, intVal);
761   }
762   return DenseIntOrFPElementsAttr::getRaw(type, data,
763                                           /*isSplat=*/(values.size() == 1));
764 }
765 
766 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
767                                          ArrayRef<bool> values) {
768   assert(hasSameElementsOrSplat(type, values));
769   assert(type.getElementType().isInteger(1));
770 
771   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
772   for (int i = 0, e = values.size(); i != e; ++i)
773     setBit(buff.data(), i, values[i]);
774   return DenseIntOrFPElementsAttr::getRaw(type, buff,
775                                           /*isSplat=*/(values.size() == 1));
776 }
777 
778 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
779                                          ArrayRef<StringRef> values) {
780   assert(!type.getElementType().isIntOrFloat());
781   return DenseStringElementsAttr::get(type, values);
782 }
783 
784 /// Constructs a dense integer elements attribute from an array of APInt
785 /// values. Each APInt value is expected to have the same bitwidth as the
786 /// element type of 'type'.
787 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
788                                          ArrayRef<APInt> values) {
789   assert(type.getElementType().isIntOrIndex());
790   assert(hasSameElementsOrSplat(type, values));
791   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
792   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
793                                           /*isSplat=*/(values.size() == 1));
794 }
795 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
796                                          ArrayRef<std::complex<APInt>> values) {
797   ComplexType complex = type.getElementType().cast<ComplexType>();
798   assert(complex.getElementType().isa<IntegerType>());
799   assert(hasSameElementsOrSplat(type, values));
800   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
801   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
802                           values.size() * 2);
803   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
804                                           /*isSplat=*/(values.size() == 1));
805 }
806 
807 // Constructs a dense float elements attribute from an array of APFloat
808 // values. Each APFloat value is expected to have the same bitwidth as the
809 // element type of 'type'.
810 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
811                                          ArrayRef<APFloat> values) {
812   assert(type.getElementType().isa<FloatType>());
813   assert(hasSameElementsOrSplat(type, values));
814   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
815   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
816                                           /*isSplat=*/(values.size() == 1));
817 }
818 DenseElementsAttr
819 DenseElementsAttr::get(ShapedType type,
820                        ArrayRef<std::complex<APFloat>> values) {
821   ComplexType complex = type.getElementType().cast<ComplexType>();
822   assert(complex.getElementType().isa<FloatType>());
823   assert(hasSameElementsOrSplat(type, values));
824   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
825                            values.size() * 2);
826   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
827   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
828                                           /*isSplat=*/(values.size() == 1));
829 }
830 
831 /// Construct a dense elements attribute from a raw buffer representing the
832 /// data for this attribute. Users should generally not use this methods as
833 /// the expected buffer format may not be a form the user expects.
834 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
835                                                       ArrayRef<char> rawBuffer,
836                                                       bool isSplatBuffer) {
837   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
838 }
839 
840 /// Returns true if the given buffer is a valid raw buffer for the given type.
841 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
842                                          ArrayRef<char> rawBuffer,
843                                          bool &detectedSplat) {
844   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
845   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
846 
847   // Storage width of 1 is special as it is packed by the bit.
848   if (storageWidth == 1) {
849     // Check for a splat, or a buffer equal to the number of elements.
850     if ((detectedSplat = rawBuffer.size() == 1))
851       return true;
852     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
853   }
854   // All other types are 8-bit aligned.
855   if ((detectedSplat = rawBufferWidth == storageWidth))
856     return true;
857   return rawBufferWidth == (storageWidth * type.getNumElements());
858 }
859 
860 /// Check the information for a C++ data type, check if this type is valid for
861 /// the current attribute. This method is used to verify specific type
862 /// invariants that the templatized 'getValues' method cannot.
863 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
864                               bool isSigned) {
865   // Make sure that the data element size is the same as the type element width.
866   if (getDenseElementBitWidth(type) !=
867       static_cast<size_t>(dataEltSize * CHAR_BIT))
868     return false;
869 
870   // Check that the element type is either float or integer or index.
871   if (!isInt)
872     return type.isa<FloatType>();
873   if (type.isIndex())
874     return true;
875 
876   auto intType = type.dyn_cast<IntegerType>();
877   if (!intType)
878     return false;
879 
880   // Make sure signedness semantics is consistent.
881   if (intType.isSignless())
882     return true;
883   return intType.isSigned() ? isSigned : !isSigned;
884 }
885 
886 /// Defaults down the subclass implementation.
887 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
888                                                    ArrayRef<char> data,
889                                                    int64_t dataEltSize,
890                                                    bool isInt, bool isSigned) {
891   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
892                                                  isSigned);
893 }
894 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
895                                                       ArrayRef<char> data,
896                                                       int64_t dataEltSize,
897                                                       bool isInt,
898                                                       bool isSigned) {
899   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
900                                                     isInt, isSigned);
901 }
902 
903 /// A method used to verify specific type invariants that the templatized 'get'
904 /// method cannot.
905 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
906                                           bool isSigned) const {
907   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
908                              isSigned);
909 }
910 
911 /// Check the information for a C++ data type, check if this type is valid for
912 /// the current attribute.
913 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
914                                        bool isSigned) const {
915   return ::isValidIntOrFloat(
916       getType().getElementType().cast<ComplexType>().getElementType(),
917       dataEltSize / 2, isInt, isSigned);
918 }
919 
920 /// Returns if this attribute corresponds to a splat, i.e. if all element
921 /// values are the same.
922 bool DenseElementsAttr::isSplat() const {
923   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
924 }
925 
926 /// Return the held element values as a range of Attributes.
927 auto DenseElementsAttr::getAttributeValues() const
928     -> llvm::iterator_range<AttributeElementIterator> {
929   return {attr_value_begin(), attr_value_end()};
930 }
931 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
932   return AttributeElementIterator(*this, 0);
933 }
934 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
935   return AttributeElementIterator(*this, getNumElements());
936 }
937 
938 /// Return the held element values as a range of bool. The element type of
939 /// this attribute must be of integer type of bitwidth 1.
940 auto DenseElementsAttr::getBoolValues() const
941     -> llvm::iterator_range<BoolElementIterator> {
942   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
943   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
944   (void)eltType;
945   return {BoolElementIterator(*this, 0),
946           BoolElementIterator(*this, getNumElements())};
947 }
948 
949 /// Return the held element values as a range of APInts. The element type of
950 /// this attribute must be of integer type.
951 auto DenseElementsAttr::getIntValues() const
952     -> llvm::iterator_range<IntElementIterator> {
953   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
954   return {raw_int_begin(), raw_int_end()};
955 }
956 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
957   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
958   return raw_int_begin();
959 }
960 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
961   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
962   return raw_int_end();
963 }
964 auto DenseElementsAttr::getComplexIntValues() const
965     -> llvm::iterator_range<ComplexIntElementIterator> {
966   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
967   (void)eltTy;
968   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
969   return {ComplexIntElementIterator(*this, 0),
970           ComplexIntElementIterator(*this, getNumElements())};
971 }
972 
973 /// Return the held element values as a range of APFloat. The element type of
974 /// this attribute must be of float type.
975 auto DenseElementsAttr::getFloatValues() const
976     -> llvm::iterator_range<FloatElementIterator> {
977   auto elementType = getType().getElementType().cast<FloatType>();
978   const auto &elementSemantics = elementType.getFloatSemantics();
979   return {FloatElementIterator(elementSemantics, raw_int_begin()),
980           FloatElementIterator(elementSemantics, raw_int_end())};
981 }
982 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
983   return getFloatValues().begin();
984 }
985 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
986   return getFloatValues().end();
987 }
988 auto DenseElementsAttr::getComplexFloatValues() const
989     -> llvm::iterator_range<ComplexFloatElementIterator> {
990   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
991   assert(eltTy.isa<FloatType>() && "expected complex float type");
992   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
993   return {{semantics, {*this, 0}},
994           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
995 }
996 
997 /// Return the raw storage data held by this attribute.
998 ArrayRef<char> DenseElementsAttr::getRawData() const {
999   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1000 }
1001 
1002 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1003   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1004 }
1005 
1006 /// Return a new DenseElementsAttr that has the same data as the current
1007 /// attribute, but has been reshaped to 'newType'. The new type must have the
1008 /// same total number of elements as well as element type.
1009 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1010   ShapedType curType = getType();
1011   if (curType == newType)
1012     return *this;
1013 
1014   (void)curType;
1015   assert(newType.getElementType() == curType.getElementType() &&
1016          "expected the same element type");
1017   assert(newType.getNumElements() == curType.getNumElements() &&
1018          "expected the same number of elements");
1019   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1020 }
1021 
1022 DenseElementsAttr
1023 DenseElementsAttr::mapValues(Type newElementType,
1024                              function_ref<APInt(const APInt &)> mapping) const {
1025   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1026 }
1027 
1028 DenseElementsAttr DenseElementsAttr::mapValues(
1029     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1030   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // DenseStringElementsAttr
1035 //===----------------------------------------------------------------------===//
1036 
1037 DenseStringElementsAttr
1038 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1039   return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
1040                    type, values, (values.size() == 1));
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // DenseIntOrFPElementsAttr
1045 //===----------------------------------------------------------------------===//
1046 
1047 /// Utility method to write a range of APInt values to a buffer.
1048 template <typename APRangeT>
1049 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1050                                 APRangeT &&values) {
1051   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1052   size_t offset = 0;
1053   for (auto it = values.begin(), e = values.end(); it != e;
1054        ++it, offset += storageWidth) {
1055     assert((*it).getBitWidth() <= storageWidth);
1056     writeBits(data.data(), offset, *it);
1057   }
1058 }
1059 
1060 /// Constructs a dense elements attribute from an array of raw APFloat values.
1061 /// Each APFloat value is expected to have the same bitwidth as the element
1062 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1063 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1064                                                    size_t storageWidth,
1065                                                    ArrayRef<APFloat> values,
1066                                                    bool isSplat) {
1067   std::vector<char> data;
1068   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1069   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1070   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1071 }
1072 
1073 /// Constructs a dense elements attribute from an array of raw APInt values.
1074 /// Each APInt value is expected to have the same bitwidth as the element type
1075 /// of 'type'.
1076 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1077                                                    size_t storageWidth,
1078                                                    ArrayRef<APInt> values,
1079                                                    bool isSplat) {
1080   std::vector<char> data;
1081   writeAPIntsToBuffer(storageWidth, data, values);
1082   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1083 }
1084 
1085 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1086                                                    ArrayRef<char> data,
1087                                                    bool isSplat) {
1088   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1089          "type must be ranked tensor or vector");
1090   assert(type.hasStaticShape() && "type must have static shape");
1091   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
1092                    type, data, isSplat);
1093 }
1094 
1095 /// Overload of the raw 'get' method that asserts that the given type is of
1096 /// complex type. This method is used to verify type invariants that the
1097 /// templatized 'get' method cannot.
1098 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1099                                                           ArrayRef<char> data,
1100                                                           int64_t dataEltSize,
1101                                                           bool isInt,
1102                                                           bool isSigned) {
1103   assert(::isValidIntOrFloat(
1104       type.getElementType().cast<ComplexType>().getElementType(),
1105       dataEltSize / 2, isInt, isSigned));
1106 
1107   int64_t numElements = data.size() / dataEltSize;
1108   assert(numElements == 1 || numElements == type.getNumElements());
1109   return getRaw(type, data, /*isSplat=*/numElements == 1);
1110 }
1111 
1112 /// Overload of the 'getRaw' method that asserts that the given type is of
1113 /// integer type. This method is used to verify type invariants that the
1114 /// templatized 'get' method cannot.
1115 DenseElementsAttr
1116 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1117                                            int64_t dataEltSize, bool isInt,
1118                                            bool isSigned) {
1119   assert(
1120       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1121 
1122   int64_t numElements = data.size() / dataEltSize;
1123   assert(numElements == 1 || numElements == type.getNumElements());
1124   return getRaw(type, data, /*isSplat=*/numElements == 1);
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // DenseFPElementsAttr
1129 //===----------------------------------------------------------------------===//
1130 
1131 template <typename Fn, typename Attr>
1132 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1133                                 Type newElementType,
1134                                 llvm::SmallVectorImpl<char> &data) {
1135   size_t bitWidth = getDenseElementBitWidth(newElementType);
1136   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1137 
1138   ShapedType newArrayType;
1139   if (inType.isa<RankedTensorType>())
1140     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1141   else if (inType.isa<UnrankedTensorType>())
1142     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1143   else if (inType.isa<VectorType>())
1144     newArrayType = VectorType::get(inType.getShape(), newElementType);
1145   else
1146     assert(newArrayType && "Unhandled tensor type");
1147 
1148   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1149   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1150 
1151   // Functor used to process a single element value of the attribute.
1152   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1153     auto newInt = mapping(value);
1154     assert(newInt.getBitWidth() == bitWidth);
1155     writeBits(data.data(), index * storageBitWidth, newInt);
1156   };
1157 
1158   // Check for the splat case.
1159   if (attr.isSplat()) {
1160     processElt(*attr.begin(), /*index=*/0);
1161     return newArrayType;
1162   }
1163 
1164   // Otherwise, process all of the element values.
1165   uint64_t elementIdx = 0;
1166   for (auto value : attr)
1167     processElt(value, elementIdx++);
1168   return newArrayType;
1169 }
1170 
1171 DenseElementsAttr DenseFPElementsAttr::mapValues(
1172     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1173   llvm::SmallVector<char, 8> elementData;
1174   auto newArrayType =
1175       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1176 
1177   return getRaw(newArrayType, elementData, isSplat());
1178 }
1179 
1180 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1181 bool DenseFPElementsAttr::classof(Attribute attr) {
1182   return attr.isa<DenseElementsAttr>() &&
1183          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // DenseIntElementsAttr
1188 //===----------------------------------------------------------------------===//
1189 
1190 DenseElementsAttr DenseIntElementsAttr::mapValues(
1191     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1192   llvm::SmallVector<char, 8> elementData;
1193   auto newArrayType =
1194       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1195 
1196   return getRaw(newArrayType, elementData, isSplat());
1197 }
1198 
1199 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1200 bool DenseIntElementsAttr::classof(Attribute attr) {
1201   return attr.isa<DenseElementsAttr>() &&
1202          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1203 }
1204 
1205 //===----------------------------------------------------------------------===//
1206 // OpaqueElementsAttr
1207 //===----------------------------------------------------------------------===//
1208 
1209 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1210                                            StringRef bytes) {
1211   assert(TensorType::isValidElementType(type.getElementType()) &&
1212          "Input element type should be a valid tensor element type");
1213   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
1214                    dialect, bytes);
1215 }
1216 
1217 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1218 
1219 /// Return the value at the given index. If index does not refer to a valid
1220 /// element, then a null attribute is returned.
1221 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1222   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1223   if (Dialect *dialect = getDialect())
1224     return dialect->extractElementHook(*this, index);
1225   return Attribute();
1226 }
1227 
1228 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1229 
1230 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1231   if (auto *d = getDialect())
1232     return d->decodeHook(*this, result);
1233   return true;
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // SparseElementsAttr
1238 //===----------------------------------------------------------------------===//
1239 
1240 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1241                                            DenseElementsAttr indices,
1242                                            DenseElementsAttr values) {
1243   assert(indices.getType().getElementType().isInteger(64) &&
1244          "expected sparse indices to be 64-bit integer values");
1245   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1246          "type must be ranked tensor or vector");
1247   assert(type.hasStaticShape() && "type must have static shape");
1248   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
1249                    indices.cast<DenseIntElementsAttr>(), values);
1250 }
1251 
1252 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1253   return getImpl()->indices;
1254 }
1255 
1256 DenseElementsAttr SparseElementsAttr::getValues() const {
1257   return getImpl()->values;
1258 }
1259 
1260 /// Return the value of the element at the given index.
1261 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1262   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1263   auto type = getType();
1264 
1265   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1266   // as a 1-D index array.
1267   auto sparseIndices = getIndices();
1268   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1269 
1270   // Check to see if the indices are a splat.
1271   if (sparseIndices.isSplat()) {
1272     // If the index is also not a splat of the index value, we know that the
1273     // value is zero.
1274     auto splatIndex = *sparseIndexValues.begin();
1275     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1276       return getZeroAttr();
1277 
1278     // If the indices are a splat, we also expect the values to be a splat.
1279     assert(getValues().isSplat() && "expected splat values");
1280     return getValues().getSplatValue();
1281   }
1282 
1283   // Build a mapping between known indices and the offset of the stored element.
1284   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1285   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1286   size_t rank = type.getRank();
1287   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1288     mappedIndices.try_emplace(
1289         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1290 
1291   // Look for the provided index key within the mapped indices. If the provided
1292   // index is not found, then return a zero attribute.
1293   auto it = mappedIndices.find(index);
1294   if (it == mappedIndices.end())
1295     return getZeroAttr();
1296 
1297   // Otherwise, return the held sparse value element.
1298   return getValues().getValue(it->second);
1299 }
1300 
1301 /// Get a zero APFloat for the given sparse attribute.
1302 APFloat SparseElementsAttr::getZeroAPFloat() const {
1303   auto eltType = getType().getElementType().cast<FloatType>();
1304   return APFloat(eltType.getFloatSemantics());
1305 }
1306 
1307 /// Get a zero APInt for the given sparse attribute.
1308 APInt SparseElementsAttr::getZeroAPInt() const {
1309   auto eltType = getType().getElementType().cast<IntegerType>();
1310   return APInt::getNullValue(eltType.getWidth());
1311 }
1312 
1313 /// Get a zero attribute for the given attribute type.
1314 Attribute SparseElementsAttr::getZeroAttr() const {
1315   auto eltType = getType().getElementType();
1316 
1317   // Handle floating point elements.
1318   if (eltType.isa<FloatType>())
1319     return FloatAttr::get(eltType, 0);
1320 
1321   // Otherwise, this is an integer.
1322   auto intEltTy = eltType.cast<IntegerType>();
1323   if (intEltTy.getWidth() == 1)
1324     return BoolAttr::get(false, eltType.getContext());
1325   return IntegerAttr::get(eltType, 0);
1326 }
1327 
1328 /// Flatten, and return, all of the sparse indices in this attribute in
1329 /// row-major order.
1330 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1331   std::vector<ptrdiff_t> flatSparseIndices;
1332 
1333   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1334   // as a 1-D index array.
1335   auto sparseIndices = getIndices();
1336   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1337   if (sparseIndices.isSplat()) {
1338     SmallVector<uint64_t, 8> indices(getType().getRank(),
1339                                      *sparseIndexValues.begin());
1340     flatSparseIndices.push_back(getFlattenedIndex(indices));
1341     return flatSparseIndices;
1342   }
1343 
1344   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1345   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1346   size_t rank = getType().getRank();
1347   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1348     flatSparseIndices.push_back(getFlattenedIndex(
1349         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1350   return flatSparseIndices;
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // MutableDictionaryAttr
1355 //===----------------------------------------------------------------------===//
1356 
1357 MutableDictionaryAttr::MutableDictionaryAttr(
1358     ArrayRef<NamedAttribute> attributes) {
1359   setAttrs(attributes);
1360 }
1361 
1362 /// Return the underlying dictionary attribute.
1363 DictionaryAttr
1364 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1365   // Construct empty DictionaryAttr if needed.
1366   if (!attrs)
1367     return DictionaryAttr::get({}, context);
1368   return attrs;
1369 }
1370 
1371 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1372   return attrs ? attrs.getValue() : llvm::None;
1373 }
1374 
1375 /// Replace the held attributes with ones provided in 'newAttrs'.
1376 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1377   // Don't create an attribute list if there are no attributes.
1378   if (attributes.empty())
1379     attrs = nullptr;
1380   else
1381     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1382 }
1383 
1384 /// Return the specified attribute if present, null otherwise.
1385 Attribute MutableDictionaryAttr::get(StringRef name) const {
1386   return attrs ? attrs.get(name) : nullptr;
1387 }
1388 
1389 /// Return the specified attribute if present, null otherwise.
1390 Attribute MutableDictionaryAttr::get(Identifier name) const {
1391   return attrs ? attrs.get(name) : nullptr;
1392 }
1393 
1394 /// Return the specified named attribute if present, None otherwise.
1395 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1396   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1397 }
1398 Optional<NamedAttribute>
1399 MutableDictionaryAttr::getNamed(Identifier name) const {
1400   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1401 }
1402 
1403 /// If the an attribute exists with the specified name, change it to the new
1404 /// value.  Otherwise, add a new attribute with the specified name/value.
1405 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1406   assert(value && "attributes may never be null");
1407 
1408   // Look for an existing value for the given name, and set it in-place.
1409   ArrayRef<NamedAttribute> values = getAttrs();
1410   const auto *it = llvm::find_if(
1411       values, [name](NamedAttribute attr) { return attr.first == name; });
1412   if (it != values.end()) {
1413     // Bail out early if the value is the same as what we already have.
1414     if (it->second == value)
1415       return;
1416 
1417     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1418     newAttrs[it - values.begin()].second = value;
1419     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1420     return;
1421   }
1422 
1423   // Otherwise, insert the new attribute into its sorted position.
1424   it = llvm::lower_bound(values, name);
1425   SmallVector<NamedAttribute, 8> newAttrs;
1426   newAttrs.reserve(values.size() + 1);
1427   newAttrs.append(values.begin(), it);
1428   newAttrs.push_back({name, value});
1429   newAttrs.append(it, values.end());
1430   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1431 }
1432 
1433 /// Remove the attribute with the specified name if it exists.  The return
1434 /// value indicates whether the attribute was present or not.
1435 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1436   auto origAttrs = getAttrs();
1437   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1438     if (origAttrs[i].first == name) {
1439       // Handle the simple case of removing the only attribute in the list.
1440       if (e == 1) {
1441         attrs = nullptr;
1442         return RemoveResult::Removed;
1443       }
1444 
1445       SmallVector<NamedAttribute, 8> newAttrs;
1446       newAttrs.reserve(origAttrs.size() - 1);
1447       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1448       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1449       attrs = DictionaryAttr::getWithSorted(newAttrs,
1450                                             newAttrs[0].second.getContext());
1451       return RemoveResult::Removed;
1452     }
1453   }
1454   return RemoveResult::NotFound;
1455 }
1456 
1457 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
1458   return strcmp(lhs.first.data(), rhs.first.data()) < 0;
1459 }
1460 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
1461   // This is correct even when attr.first.data()[name.size()] is not a zero
1462   // string terminator, because we only care about a less than comparison.
1463   // This can't use memcmp, because it doesn't guarantee that it will stop
1464   // reading both buffers if one is shorter than the other, even if there is
1465   // a difference.
1466   return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
1467 }
1468