//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 //===----------------------------------------------------------------------===//
26 // AttributeStorage
27 //===----------------------------------------------------------------------===//
28 
29 AttributeStorage::AttributeStorage(Type type)
30     : type(type.getAsOpaquePointer()) {}
31 AttributeStorage::AttributeStorage() : type(nullptr) {}
32 
33 Type AttributeStorage::getType() const {
34   return Type::getFromOpaquePointer(type);
35 }
36 void AttributeStorage::setType(Type newType) {
37   type = newType.getAsOpaquePointer();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // Attribute
42 //===----------------------------------------------------------------------===//
43 
44 /// Return the type of this attribute.
45 Type Attribute::getType() const { return impl->getType(); }
46 
47 /// Return the context this attribute belongs to.
48 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
49 
50 /// Get the dialect this attribute is registered to.
51 Dialect &Attribute::getDialect() const {
52   return impl->getAbstractAttribute().getDialect();
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // AffineMapAttr
57 //===----------------------------------------------------------------------===//
58 
59 AffineMapAttr AffineMapAttr::get(AffineMap value) {
60   return Base::get(value.getContext(), value);
61 }
62 
63 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
64 
65 //===----------------------------------------------------------------------===//
66 // ArrayAttr
67 //===----------------------------------------------------------------------===//
68 
69 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
70   return Base::get(context, value);
71 }
72 
73 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
74 
75 Attribute ArrayAttr::operator[](unsigned idx) const {
76   assert(idx < size() && "index out of bounds");
77   return getValue()[idx];
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // DictionaryAttr
82 //===----------------------------------------------------------------------===//
83 
/// Helper that either sorts in place or sorts from a source array into a
/// destination array.
///
/// If `inPlace` is true, `storage` is both the source and the destination, and
/// `value` is expected to view the same elements (`sortInPlace` passes the same
/// array for both). Otherwise `value` is the source and `storage` the
/// destination.
///
/// Returns true when the source was NOT already sorted, i.e. when the caller
/// must use `storage` in place of the original `value`. Returns false for an
/// already-sorted source, including 0 or 1 elements.
///
/// When `inPlace` is false and the source is non-empty, `storage` is filled
/// with the (sorted) elements even if the source was already sorted. An empty
/// source leaves `storage` untouched.
///
/// In debug builds, asserts that element names are unique.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
    // Zero already sorted.
    break;
  case 1:
    // One already sorted but may need to be copied.
    if (!inPlace)
      storage.assign({value[0]});
    break;
  case 2: {
    // Two elements: a single comparison decides the order.
    assert(value[0].first != value[1].first &&
           "DictionaryAttr element names must be unique");
    bool isSorted = value[0] < value[1];
    if (inPlace) {
      // `value` aliases `storage` here, so swapping storage reorders the input.
      if (!isSorted)
        std::swap(storage[0], storage[1]);
    } else if (isSorted) {
      storage.assign({value[0], value[1]});
    } else {
      storage.assign({value[1], value[0]});
    }
    return !isSorted;
  }
  default:
    if (!inPlace)
      storage.assign(value.begin(), value.end());
    // Check to see they are sorted already.
    bool isSorted = llvm::is_sorted(value);
    if (!isSorted) {
      // If not, do a general sort.
      llvm::array_pod_sort(storage.begin(), storage.end());
      // Re-point `value` at the sorted data so the uniqueness check below,
      // which only detects adjacent duplicates, runs on sorted elements.
      value = storage;
    }

    // Ensure that the attribute elements are unique.
    assert(std::adjacent_find(value.begin(), value.end(),
                              [](NamedAttribute l, NamedAttribute r) {
                                return l.first == r.first;
                              }) == value.end() &&
           "DictionaryAttr element names must be unique");
    return !isSorted;
  }
  // Reached only for 0 or 1 elements, which are trivially sorted.
  return false;
}
136 
137 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
138                           SmallVectorImpl<NamedAttribute> &storage) {
139   return dictionaryAttrSort</*inPlace=*/false>(value, storage);
140 }
141 
142 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
143   return dictionaryAttrSort</*inPlace=*/true>(array, array);
144 }
145 
146 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
147                                    MLIRContext *context) {
148   if (value.empty())
149     return DictionaryAttr::getEmpty(context);
150   assert(llvm::all_of(value,
151                       [](const NamedAttribute &attr) { return attr.second; }) &&
152          "value cannot have null entries");
153 
154   // We need to sort the element list to canonicalize it.
155   SmallVector<NamedAttribute, 8> storage;
156   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
157     value = storage;
158 
159   return Base::get(context, value);
160 }
161 /// Construct a dictionary with an array of values that is known to already be
162 /// sorted by name and uniqued.
163 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
164                                              MLIRContext *context) {
165   if (value.empty())
166     return DictionaryAttr::getEmpty(context);
167   // Ensure that the attribute elements are unique and sorted.
168   assert(llvm::is_sorted(value,
169                          [](NamedAttribute l, NamedAttribute r) {
170                            return l.first.strref() < r.first.strref();
171                          }) &&
172          "expected attribute values to be sorted");
173   assert(std::adjacent_find(value.begin(), value.end(),
174                             [](NamedAttribute l, NamedAttribute r) {
175                               return l.first == r.first;
176                             }) == value.end() &&
177          "DictionaryAttr element names must be unique");
178   return Base::get(context, value);
179 }
180 
181 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
182   return getImpl()->getElements();
183 }
184 
185 /// Return the specified attribute if present, null otherwise.
186 Attribute DictionaryAttr::get(StringRef name) const {
187   Optional<NamedAttribute> attr = getNamed(name);
188   return attr ? attr->second : nullptr;
189 }
190 Attribute DictionaryAttr::get(Identifier name) const {
191   Optional<NamedAttribute> attr = getNamed(name);
192   return attr ? attr->second : nullptr;
193 }
194 
195 /// Return the specified named attribute if present, None otherwise.
196 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
197   ArrayRef<NamedAttribute> values = getValue();
198   const auto *it = llvm::lower_bound(values, name);
199   return it != values.end() && it->first == name ? *it
200                                                  : Optional<NamedAttribute>();
201 }
202 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
203   for (auto elt : getValue())
204     if (elt.first == name)
205       return elt;
206   return llvm::None;
207 }
208 
209 DictionaryAttr::iterator DictionaryAttr::begin() const {
210   return getValue().begin();
211 }
212 DictionaryAttr::iterator DictionaryAttr::end() const {
213   return getValue().end();
214 }
215 size_t DictionaryAttr::size() const { return getValue().size(); }
216 
217 //===----------------------------------------------------------------------===//
218 // FloatAttr
219 //===----------------------------------------------------------------------===//
220 
221 FloatAttr FloatAttr::get(Type type, double value) {
222   return Base::get(type.getContext(), type, value);
223 }
224 
225 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
226   return Base::getChecked(loc, type, value);
227 }
228 
229 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
230   return Base::get(type.getContext(), type, value);
231 }
232 
233 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
234   return Base::getChecked(loc, type, value);
235 }
236 
237 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
238 
239 double FloatAttr::getValueAsDouble() const {
240   return getValueAsDouble(getValue());
241 }
242 double FloatAttr::getValueAsDouble(APFloat value) {
243   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
244     bool losesInfo = false;
245     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
246                   &losesInfo);
247   }
248   return value.convertToDouble();
249 }
250 
251 /// Verify construction invariants.
252 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
253   if (!type.isa<FloatType>())
254     return emitError(loc, "expected floating point type");
255   return success();
256 }
257 
258 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
259                                                       double value) {
260   return verifyFloatTypeInvariants(loc, type);
261 }
262 
263 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
264                                                       const APFloat &value) {
265   // Verify that the type is correct.
266   if (failed(verifyFloatTypeInvariants(loc, type)))
267     return failure();
268 
269   // Verify that the type semantics match that of the value.
270   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
271     return emitError(
272         loc, "FloatAttr type doesn't match the type implied by its value");
273   }
274   return success();
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // SymbolRefAttr
279 //===----------------------------------------------------------------------===//
280 
281 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
282   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
283 }
284 
285 SymbolRefAttr SymbolRefAttr::get(StringRef value,
286                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
287                                  MLIRContext *ctx) {
288   return Base::get(ctx, value, nestedReferences);
289 }
290 
291 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
292 
293 StringRef SymbolRefAttr::getLeafReference() const {
294   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
295   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
296 }
297 
298 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
299   return getImpl()->getNestedRefs();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // IntegerAttr
304 //===----------------------------------------------------------------------===//
305 
306 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
307   if (type.isSignlessInteger(1))
308     return BoolAttr::get(value.getBoolValue(), type.getContext());
309   return Base::get(type.getContext(), type, value);
310 }
311 
312 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
313   // This uses 64 bit APInts by default for index type.
314   if (type.isIndex())
315     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
316 
317   auto intType = type.cast<IntegerType>();
318   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
319 }
320 
321 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
322 
323 int64_t IntegerAttr::getInt() const {
324   assert((getImpl()->getType().isIndex() ||
325           getImpl()->getType().isSignlessInteger()) &&
326          "must be signless integer");
327   return getValue().getSExtValue();
328 }
329 
330 int64_t IntegerAttr::getSInt() const {
331   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
332   return getValue().getSExtValue();
333 }
334 
335 uint64_t IntegerAttr::getUInt() const {
336   assert(getImpl()->getType().isUnsignedInteger() &&
337          "must be unsigned integer");
338   return getValue().getZExtValue();
339 }
340 
341 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
342   if (type.isa<IntegerType, IndexType>())
343     return success();
344   return emitError(loc, "expected integer or index type");
345 }
346 
347 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
348                                                         int64_t value) {
349   return verifyIntegerTypeInvariants(loc, type);
350 }
351 
352 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
353                                                         const APInt &value) {
354   if (failed(verifyIntegerTypeInvariants(loc, type)))
355     return failure();
356   if (auto integerType = type.dyn_cast<IntegerType>())
357     if (integerType.getWidth() != value.getBitWidth())
358       return emitError(loc, "integer type bit width (")
359              << integerType.getWidth() << ") doesn't match value bit width ("
360              << value.getBitWidth() << ")";
361   return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // BoolAttr
366 
367 bool BoolAttr::getValue() const {
368   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
369   return storage->getValue().getBoolValue();
370 }
371 
372 bool BoolAttr::classof(Attribute attr) {
373   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
374   return intAttr && intAttr.getType().isSignlessInteger(1);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // IntegerSetAttr
379 //===----------------------------------------------------------------------===//
380 
381 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
382   return Base::get(value.getConstraint(0).getContext(), value);
383 }
384 
385 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
386 
387 //===----------------------------------------------------------------------===//
388 // OpaqueAttr
389 //===----------------------------------------------------------------------===//
390 
391 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
392                            MLIRContext *context) {
393   return Base::get(context, dialect, attrData, type);
394 }
395 
396 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
397                                   Type type, Location location) {
398   return Base::getChecked(location, dialect, attrData, type);
399 }
400 
401 /// Returns the dialect namespace of the opaque attribute.
402 Identifier OpaqueAttr::getDialectNamespace() const {
403   return getImpl()->dialectNamespace;
404 }
405 
406 /// Returns the raw attribute data of the opaque attribute.
407 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
408 
409 /// Verify the construction of an opaque attribute.
410 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
411                                                        Identifier dialect,
412                                                        StringRef attrData,
413                                                        Type type) {
414   if (!Dialect::isValidNamespace(dialect.strref()))
415     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
416   return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // StringAttr
421 //===----------------------------------------------------------------------===//
422 
423 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
424   return get(bytes, NoneType::get(context));
425 }
426 
427 /// Get an instance of a StringAttr with the given string and Type.
428 StringAttr StringAttr::get(StringRef bytes, Type type) {
429   return Base::get(type.getContext(), bytes, type);
430 }
431 
432 StringRef StringAttr::getValue() const { return getImpl()->value; }
433 
434 //===----------------------------------------------------------------------===//
435 // TypeAttr
436 //===----------------------------------------------------------------------===//
437 
438 TypeAttr TypeAttr::get(Type value) {
439   return Base::get(value.getContext(), value);
440 }
441 
442 Type TypeAttr::getValue() const { return getImpl()->value; }
443 
444 //===----------------------------------------------------------------------===//
445 // ElementsAttr
446 //===----------------------------------------------------------------------===//
447 
448 ShapedType ElementsAttr::getType() const {
449   return Attribute::getType().cast<ShapedType>();
450 }
451 
452 /// Returns the number of elements held by this attribute.
453 int64_t ElementsAttr::getNumElements() const {
454   return getType().getNumElements();
455 }
456 
457 /// Return the value at the given index. If index does not refer to a valid
458 /// element, then a null attribute is returned.
459 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
460   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
461     return denseAttr.getValue(index);
462   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
463     return opaqueAttr.getValue(index);
464   return cast<SparseElementsAttr>().getValue(index);
465 }
466 
467 /// Return if the given 'index' refers to a valid element in this attribute.
468 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
469   auto type = getType();
470 
471   // Verify that the rank of the indices matches the held type.
472   auto rank = type.getRank();
473   if (rank != static_cast<int64_t>(index.size()))
474     return false;
475 
476   // Verify that all of the indices are within the shape dimensions.
477   auto shape = type.getShape();
478   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
479     return static_cast<int64_t>(index[i]) < shape[i];
480   });
481 }
482 
483 ElementsAttr
484 ElementsAttr::mapValues(Type newElementType,
485                         function_ref<APInt(const APInt &)> mapping) const {
486   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
487     return intOrFpAttr.mapValues(newElementType, mapping);
488   llvm_unreachable("unsupported ElementsAttr subtype");
489 }
490 
491 ElementsAttr
492 ElementsAttr::mapValues(Type newElementType,
493                         function_ref<APInt(const APFloat &)> mapping) const {
494   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
495     return intOrFpAttr.mapValues(newElementType, mapping);
496   llvm_unreachable("unsupported ElementsAttr subtype");
497 }
498 
499 /// Method for support type inquiry through isa, cast and dyn_cast.
500 bool ElementsAttr::classof(Attribute attr) {
501   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
502                   OpaqueElementsAttr, SparseElementsAttr>();
503 }
504 
505 /// Returns the 1 dimensional flattened row-major index from the given
506 /// multi-dimensional index.
507 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
508   assert(isValidIndex(index) && "expected valid multi-dimensional index");
509   auto type = getType();
510 
511   // Reduce the provided multidimensional index into a flattended 1D row-major
512   // index.
513   auto rank = type.getRank();
514   auto shape = type.getShape();
515   uint64_t valueIndex = 0;
516   uint64_t dimMultiplier = 1;
517   for (int i = rank - 1; i >= 0; --i) {
518     valueIndex += index[i] * dimMultiplier;
519     dimMultiplier *= shape[i];
520   }
521   return valueIndex;
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // DenseElementsAttr Utilities
526 //===----------------------------------------------------------------------===//
527 
528 /// Get the bitwidth of a dense element type within the buffer.
529 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
530 static size_t getDenseElementStorageWidth(size_t origWidth) {
531   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
532 }
533 static size_t getDenseElementStorageWidth(Type elementType) {
534   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
535 }
536 
/// Sets bit `bitPos` of the byte array `rawData` to `value`. Bit 0 is the
/// least-significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}

/// Returns bit `bitPos` of the byte array `rawData`, using the same bit order
/// as `setBit`.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
549 
/// Returns a pointer to the first significant byte in the raw word storage of
/// `value` (as exposed by `getRawData()`). On little-endian hosts this is the
/// start of the storage. On big-endian hosts the `ceil(bitWidth / CHAR_BIT)`
/// significant bytes sit at the end of the first 8-byte word, so the pointer
/// is advanced to them.
/// The pointer is made mutable via const_cast so that `writeAPInt` can fill
/// the APInt's storage directly.
/// NOTE(review): the big-endian adjustment stays inside the buffer only when
/// `bitWidth <= 64`. For wider values `8 - ceil(bitWidth / CHAR_BIT)` goes
/// negative and the pointer lands before the storage — confirm whether
/// multi-word values are supported on big-endian hosts.
static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
  char *dataPos =
      const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big)
    dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
  return dataPos;
}

/// Copies the `ceil(bitWidth / CHAR_BIT)` significant bytes of `value` into
/// `outData`. Despite the name, this reads *from* the APInt and writes to
/// `outData`; `value` itself is not modified.
static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
}

/// Copies `ceil(bitWidth / CHAR_BIT)` bytes from `inData` directly into the
/// significant bytes of `value`'s storage.
/// NOTE(review): bytes are copied verbatim, so bits above `bitWidth` in the
/// last byte are not cleared — presumably callers guarantee they are zero.
static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
}
572 
573 /// Writes value to the bit position `bitPos` in array `rawData`.
574 static void writeBits(char *rawData, size_t bitPos, APInt value) {
575   size_t bitWidth = value.getBitWidth();
576 
577   // If the bitwidth is 1 we just toggle the specific bit.
578   if (bitWidth == 1)
579     return setBit(rawData, bitPos, value.isOneValue());
580 
581   // Otherwise, the bit position is guaranteed to be byte aligned.
582   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
583   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
584 }
585 
586 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
587 /// `rawData`.
588 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
589   // Handle a boolean bit position.
590   if (bitWidth == 1)
591     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
592 
593   // Otherwise, the bit position must be 8-bit aligned.
594   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
595   APInt result(bitWidth, 0);
596   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
597   return result;
598 }
599 
600 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
601 /// the same element count as 'type'.
602 template <typename Values>
603 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
604   return (values.size() == 1) ||
605          (type.getNumElements() == static_cast<int64_t>(values.size()));
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // DenseElementsAttr Iterators
610 //===----------------------------------------------------------------------===//
611 
612 //===----------------------------------------------------------------------===//
613 // AttributeElementIterator
614 
615 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
616     DenseElementsAttr attr, size_t index)
617     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
618                                       Attribute, Attribute, Attribute>(
619           attr.getAsOpaquePointer(), index) {}
620 
621 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
622   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
623   Type eltTy = owner.getType().getElementType();
624   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
625     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
626   if (eltTy.isa<IndexType>())
627     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
628   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
629     IntElementIterator intIt(owner, index);
630     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
631     return FloatAttr::get(eltTy, *floatIt);
632   }
633   if (owner.isa<DenseStringElementsAttr>()) {
634     ArrayRef<StringRef> vals = owner.getRawStringData();
635     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
636   }
637   llvm_unreachable("unexpected element type");
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // BoolElementIterator
642 
643 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
644     DenseElementsAttr attr, size_t dataIndex)
645     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
646           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
647 
648 bool DenseElementsAttr::BoolElementIterator::operator*() const {
649   return getBit(getData(), getDataIndex());
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // IntElementIterator
654 
655 DenseElementsAttr::IntElementIterator::IntElementIterator(
656     DenseElementsAttr attr, size_t dataIndex)
657     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
658           attr.getRawData().data(), attr.isSplat(), dataIndex),
659       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
660 
661 APInt DenseElementsAttr::IntElementIterator::operator*() const {
662   return readBits(getData(),
663                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
664                   bitWidth);
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // ComplexIntElementIterator
669 
670 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
671     DenseElementsAttr attr, size_t dataIndex)
672     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
673                                       std::complex<APInt>, std::complex<APInt>,
674                                       std::complex<APInt>>(
675           attr.getRawData().data(), attr.isSplat(), dataIndex) {
676   auto complexType = attr.getType().getElementType().cast<ComplexType>();
677   bitWidth = getDenseElementBitWidth(complexType.getElementType());
678 }
679 
680 std::complex<APInt>
681 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
682   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
683   size_t offset = getDataIndex() * storageWidth * 2;
684   return {readBits(getData(), offset, bitWidth),
685           readBits(getData(), offset + storageWidth, bitWidth)};
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // FloatElementIterator
690 
691 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
692     const llvm::fltSemantics &smt, IntElementIterator it)
693     : llvm::mapped_iterator<IntElementIterator,
694                             std::function<APFloat(const APInt &)>>(
695           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
696 
697 //===----------------------------------------------------------------------===//
698 // ComplexFloatElementIterator
699 
700 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
701     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
702     : llvm::mapped_iterator<
703           ComplexIntElementIterator,
704           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
705           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
706             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
707           }) {}
708 
709 //===----------------------------------------------------------------------===//
710 // DenseElementsAttr
711 //===----------------------------------------------------------------------===//
712 
713 /// Method for support type inquiry through isa, cast and dyn_cast.
714 bool DenseElementsAttr::classof(Attribute attr) {
715   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
716 }
717 
718 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
719                                          ArrayRef<Attribute> values) {
720   assert(hasSameElementsOrSplat(type, values));
721 
722   // If the element type is not based on int/float/index, assume it is a string
723   // type.
724   auto eltType = type.getElementType();
725   if (!type.getElementType().isIntOrIndexOrFloat()) {
726     SmallVector<StringRef, 8> stringValues;
727     stringValues.reserve(values.size());
728     for (Attribute attr : values) {
729       assert(attr.isa<StringAttr>() &&
730              "expected string value for non integer/index/float element");
731       stringValues.push_back(attr.cast<StringAttr>().getValue());
732     }
733     return get(type, stringValues);
734   }
735 
736   // Otherwise, get the raw storage width to use for the allocation.
737   size_t bitWidth = getDenseElementBitWidth(eltType);
738   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
739 
740   // Compress the attribute values into a character buffer.
741   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
742                             values.size());
743   APInt intVal;
744   for (unsigned i = 0, e = values.size(); i < e; ++i) {
745     assert(eltType == values[i].getType() &&
746            "expected attribute value to have element type");
747     if (eltType.isa<FloatType>())
748       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
749     else if (eltType.isa<IntegerType>())
750       intVal = values[i].cast<IntegerAttr>().getValue();
751     else
752       llvm_unreachable("unexpected element type");
753 
754     assert(intVal.getBitWidth() == bitWidth &&
755            "expected value to have same bitwidth as element type");
756     writeBits(data.data(), i * storageBitWidth, intVal);
757   }
758   return DenseIntOrFPElementsAttr::getRaw(type, data,
759                                           /*isSplat=*/(values.size() == 1));
760 }
761 
762 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
763                                          ArrayRef<bool> values) {
764   assert(hasSameElementsOrSplat(type, values));
765   assert(type.getElementType().isInteger(1));
766 
767   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
768   for (int i = 0, e = values.size(); i != e; ++i)
769     setBit(buff.data(), i, values[i]);
770   return DenseIntOrFPElementsAttr::getRaw(type, buff,
771                                           /*isSplat=*/(values.size() == 1));
772 }
773 
774 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
775                                          ArrayRef<StringRef> values) {
776   assert(!type.getElementType().isIntOrFloat());
777   return DenseStringElementsAttr::get(type, values);
778 }
779 
780 /// Constructs a dense integer elements attribute from an array of APInt
781 /// values. Each APInt value is expected to have the same bitwidth as the
782 /// element type of 'type'.
783 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
784                                          ArrayRef<APInt> values) {
785   assert(type.getElementType().isIntOrIndex());
786   assert(hasSameElementsOrSplat(type, values));
787   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
788   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
789                                           /*isSplat=*/(values.size() == 1));
790 }
791 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
792                                          ArrayRef<std::complex<APInt>> values) {
793   ComplexType complex = type.getElementType().cast<ComplexType>();
794   assert(complex.getElementType().isa<IntegerType>());
795   assert(hasSameElementsOrSplat(type, values));
796   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
797   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
798                           values.size() * 2);
799   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
800                                           /*isSplat=*/(values.size() == 1));
801 }
802 
803 // Constructs a dense float elements attribute from an array of APFloat
804 // values. Each APFloat value is expected to have the same bitwidth as the
805 // element type of 'type'.
806 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
807                                          ArrayRef<APFloat> values) {
808   assert(type.getElementType().isa<FloatType>());
809   assert(hasSameElementsOrSplat(type, values));
810   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
811   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
812                                           /*isSplat=*/(values.size() == 1));
813 }
814 DenseElementsAttr
815 DenseElementsAttr::get(ShapedType type,
816                        ArrayRef<std::complex<APFloat>> values) {
817   ComplexType complex = type.getElementType().cast<ComplexType>();
818   assert(complex.getElementType().isa<FloatType>());
819   assert(hasSameElementsOrSplat(type, values));
820   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
821                            values.size() * 2);
822   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
823   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
824                                           /*isSplat=*/(values.size() == 1));
825 }
826 
827 /// Construct a dense elements attribute from a raw buffer representing the
828 /// data for this attribute. Users should generally not use this methods as
829 /// the expected buffer format may not be a form the user expects.
830 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
831                                                       ArrayRef<char> rawBuffer,
832                                                       bool isSplatBuffer) {
833   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
834 }
835 
/// Returns true if the given buffer is a valid raw buffer for the given type.
/// 'detectedSplat' is always written: it is set to true when the buffer holds
/// a single (splat) element rather than one entry per element of 'type'.
///
/// NOTE(review): for element types with a storage width of 1 (bit-packed),
/// a one-byte buffer is always reported as a splat. For types with 1 to 8
/// elements, one byte is also exactly the size of a full non-splat packed
/// buffer, so that case is ambiguous here. Confirm that callers never pass
/// small non-splat bit-packed buffers.
bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
                                         ArrayRef<char> rawBuffer,
                                         bool &detectedSplat) {
  size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
  size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;

  // Storage width of 1 is special as it is packed by the bit.
  if (storageWidth == 1) {
    // Check for a splat, or a buffer equal to the number of elements.
    // Any single-byte buffer is accepted as a splat.
    if ((detectedSplat = rawBuffer.size() == 1))
      return true;
    // Non-splat: one bit per element, rounded up to a whole number of bytes.
    return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
  }
  // All other types are 8-bit aligned.
  // A buffer exactly one element wide is a splat.
  if ((detectedSplat = rawBufferWidth == storageWidth))
    return true;
  // Otherwise the buffer must hold one storage slot per element.
  return rawBufferWidth == (storageWidth * type.getNumElements());
}
855 
856 /// Check the information for a C++ data type, check if this type is valid for
857 /// the current attribute. This method is used to verify specific type
858 /// invariants that the templatized 'getValues' method cannot.
859 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
860                               bool isSigned) {
861   // Make sure that the data element size is the same as the type element width.
862   if (getDenseElementBitWidth(type) !=
863       static_cast<size_t>(dataEltSize * CHAR_BIT))
864     return false;
865 
866   // Check that the element type is either float or integer or index.
867   if (!isInt)
868     return type.isa<FloatType>();
869   if (type.isIndex())
870     return true;
871 
872   auto intType = type.dyn_cast<IntegerType>();
873   if (!intType)
874     return false;
875 
876   // Make sure signedness semantics is consistent.
877   if (intType.isSignless())
878     return true;
879   return intType.isSigned() ? isSigned : !isSigned;
880 }
881 
882 /// Defaults down the subclass implementation.
883 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
884                                                    ArrayRef<char> data,
885                                                    int64_t dataEltSize,
886                                                    bool isInt, bool isSigned) {
887   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
888                                                  isSigned);
889 }
890 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
891                                                       ArrayRef<char> data,
892                                                       int64_t dataEltSize,
893                                                       bool isInt,
894                                                       bool isSigned) {
895   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
896                                                     isInt, isSigned);
897 }
898 
899 /// A method used to verify specific type invariants that the templatized 'get'
900 /// method cannot.
901 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
902                                           bool isSigned) const {
903   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
904                              isSigned);
905 }
906 
907 /// Check the information for a C++ data type, check if this type is valid for
908 /// the current attribute.
909 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
910                                        bool isSigned) const {
911   return ::isValidIntOrFloat(
912       getType().getElementType().cast<ComplexType>().getElementType(),
913       dataEltSize / 2, isInt, isSigned);
914 }
915 
916 /// Returns true if this attribute corresponds to a splat, i.e. if all element
917 /// values are the same.
918 bool DenseElementsAttr::isSplat() const {
919   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
920 }
921 
922 /// Return the held element values as a range of Attributes.
923 auto DenseElementsAttr::getAttributeValues() const
924     -> llvm::iterator_range<AttributeElementIterator> {
925   return {attr_value_begin(), attr_value_end()};
926 }
927 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
928   return AttributeElementIterator(*this, 0);
929 }
930 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
931   return AttributeElementIterator(*this, getNumElements());
932 }
933 
934 /// Return the held element values as a range of bool. The element type of
935 /// this attribute must be of integer type of bitwidth 1.
936 auto DenseElementsAttr::getBoolValues() const
937     -> llvm::iterator_range<BoolElementIterator> {
938   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
939   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
940   (void)eltType;
941   return {BoolElementIterator(*this, 0),
942           BoolElementIterator(*this, getNumElements())};
943 }
944 
945 /// Return the held element values as a range of APInts. The element type of
946 /// this attribute must be of integer type.
947 auto DenseElementsAttr::getIntValues() const
948     -> llvm::iterator_range<IntElementIterator> {
949   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
950   return {raw_int_begin(), raw_int_end()};
951 }
952 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
953   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
954   return raw_int_begin();
955 }
956 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
957   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
958   return raw_int_end();
959 }
960 auto DenseElementsAttr::getComplexIntValues() const
961     -> llvm::iterator_range<ComplexIntElementIterator> {
962   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
963   (void)eltTy;
964   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
965   return {ComplexIntElementIterator(*this, 0),
966           ComplexIntElementIterator(*this, getNumElements())};
967 }
968 
969 /// Return the held element values as a range of APFloat. The element type of
970 /// this attribute must be of float type.
971 auto DenseElementsAttr::getFloatValues() const
972     -> llvm::iterator_range<FloatElementIterator> {
973   auto elementType = getType().getElementType().cast<FloatType>();
974   const auto &elementSemantics = elementType.getFloatSemantics();
975   return {FloatElementIterator(elementSemantics, raw_int_begin()),
976           FloatElementIterator(elementSemantics, raw_int_end())};
977 }
978 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
979   return getFloatValues().begin();
980 }
981 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
982   return getFloatValues().end();
983 }
984 auto DenseElementsAttr::getComplexFloatValues() const
985     -> llvm::iterator_range<ComplexFloatElementIterator> {
986   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
987   assert(eltTy.isa<FloatType>() && "expected complex float type");
988   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
989   return {{semantics, {*this, 0}},
990           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
991 }
992 
993 /// Return the raw storage data held by this attribute.
994 ArrayRef<char> DenseElementsAttr::getRawData() const {
995   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
996 }
997 
998 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
999   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1000 }
1001 
1002 /// Return a new DenseElementsAttr that has the same data as the current
1003 /// attribute, but has been reshaped to 'newType'. The new type must have the
1004 /// same total number of elements as well as element type.
1005 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1006   ShapedType curType = getType();
1007   if (curType == newType)
1008     return *this;
1009 
1010   (void)curType;
1011   assert(newType.getElementType() == curType.getElementType() &&
1012          "expected the same element type");
1013   assert(newType.getNumElements() == curType.getNumElements() &&
1014          "expected the same number of elements");
1015   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1016 }
1017 
1018 DenseElementsAttr
1019 DenseElementsAttr::mapValues(Type newElementType,
1020                              function_ref<APInt(const APInt &)> mapping) const {
1021   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1022 }
1023 
1024 DenseElementsAttr DenseElementsAttr::mapValues(
1025     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1026   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // DenseStringElementsAttr
1031 //===----------------------------------------------------------------------===//
1032 
1033 DenseStringElementsAttr
1034 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1035   return Base::get(type.getContext(), type, values, (values.size() == 1));
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // DenseIntOrFPElementsAttr
1040 //===----------------------------------------------------------------------===//
1041 
1042 /// Utility method to write a range of APInt values to a buffer.
1043 template <typename APRangeT>
1044 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1045                                 APRangeT &&values) {
1046   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1047   size_t offset = 0;
1048   for (auto it = values.begin(), e = values.end(); it != e;
1049        ++it, offset += storageWidth) {
1050     assert((*it).getBitWidth() <= storageWidth);
1051     writeBits(data.data(), offset, *it);
1052   }
1053 }
1054 
1055 /// Constructs a dense elements attribute from an array of raw APFloat values.
1056 /// Each APFloat value is expected to have the same bitwidth as the element
1057 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1058 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1059                                                    size_t storageWidth,
1060                                                    ArrayRef<APFloat> values,
1061                                                    bool isSplat) {
1062   std::vector<char> data;
1063   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1064   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1065   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1066 }
1067 
1068 /// Constructs a dense elements attribute from an array of raw APInt values.
1069 /// Each APInt value is expected to have the same bitwidth as the element type
1070 /// of 'type'.
1071 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1072                                                    size_t storageWidth,
1073                                                    ArrayRef<APInt> values,
1074                                                    bool isSplat) {
1075   std::vector<char> data;
1076   writeAPIntsToBuffer(storageWidth, data, values);
1077   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1078 }
1079 
1080 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1081                                                    ArrayRef<char> data,
1082                                                    bool isSplat) {
1083   assert((type.isa<RankedTensorType, VectorType>()) &&
1084          "type must be ranked tensor or vector");
1085   assert(type.hasStaticShape() && "type must have static shape");
1086   return Base::get(type.getContext(), type, data, isSplat);
1087 }
1088 
1089 /// Overload of the raw 'get' method that asserts that the given type is of
1090 /// complex type. This method is used to verify type invariants that the
1091 /// templatized 'get' method cannot.
1092 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1093                                                           ArrayRef<char> data,
1094                                                           int64_t dataEltSize,
1095                                                           bool isInt,
1096                                                           bool isSigned) {
1097   assert(::isValidIntOrFloat(
1098       type.getElementType().cast<ComplexType>().getElementType(),
1099       dataEltSize / 2, isInt, isSigned));
1100 
1101   int64_t numElements = data.size() / dataEltSize;
1102   assert(numElements == 1 || numElements == type.getNumElements());
1103   return getRaw(type, data, /*isSplat=*/numElements == 1);
1104 }
1105 
1106 /// Overload of the 'getRaw' method that asserts that the given type is of
1107 /// integer type. This method is used to verify type invariants that the
1108 /// templatized 'get' method cannot.
1109 DenseElementsAttr
1110 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1111                                            int64_t dataEltSize, bool isInt,
1112                                            bool isSigned) {
1113   assert(
1114       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1115 
1116   int64_t numElements = data.size() / dataEltSize;
1117   assert(numElements == 1 || numElements == type.getNumElements());
1118   return getRaw(type, data, /*isSplat=*/numElements == 1);
1119 }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // DenseFPElementsAttr
1123 //===----------------------------------------------------------------------===//
1124 
1125 template <typename Fn, typename Attr>
1126 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1127                                 Type newElementType,
1128                                 llvm::SmallVectorImpl<char> &data) {
1129   size_t bitWidth = getDenseElementBitWidth(newElementType);
1130   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1131 
1132   ShapedType newArrayType;
1133   if (inType.isa<RankedTensorType>())
1134     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1135   else if (inType.isa<UnrankedTensorType>())
1136     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1137   else if (inType.isa<VectorType>())
1138     newArrayType = VectorType::get(inType.getShape(), newElementType);
1139   else
1140     assert(newArrayType && "Unhandled tensor type");
1141 
1142   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1143   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1144 
1145   // Functor used to process a single element value of the attribute.
1146   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1147     auto newInt = mapping(value);
1148     assert(newInt.getBitWidth() == bitWidth);
1149     writeBits(data.data(), index * storageBitWidth, newInt);
1150   };
1151 
1152   // Check for the splat case.
1153   if (attr.isSplat()) {
1154     processElt(*attr.begin(), /*index=*/0);
1155     return newArrayType;
1156   }
1157 
1158   // Otherwise, process all of the element values.
1159   uint64_t elementIdx = 0;
1160   for (auto value : attr)
1161     processElt(value, elementIdx++);
1162   return newArrayType;
1163 }
1164 
1165 DenseElementsAttr DenseFPElementsAttr::mapValues(
1166     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1167   llvm::SmallVector<char, 8> elementData;
1168   auto newArrayType =
1169       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1170 
1171   return getRaw(newArrayType, elementData, isSplat());
1172 }
1173 
1174 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1175 bool DenseFPElementsAttr::classof(Attribute attr) {
1176   return attr.isa<DenseElementsAttr>() &&
1177          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // DenseIntElementsAttr
1182 //===----------------------------------------------------------------------===//
1183 
1184 DenseElementsAttr DenseIntElementsAttr::mapValues(
1185     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1186   llvm::SmallVector<char, 8> elementData;
1187   auto newArrayType =
1188       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1189 
1190   return getRaw(newArrayType, elementData, isSplat());
1191 }
1192 
1193 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1194 bool DenseIntElementsAttr::classof(Attribute attr) {
1195   return attr.isa<DenseElementsAttr>() &&
1196          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // OpaqueElementsAttr
1201 //===----------------------------------------------------------------------===//
1202 
1203 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1204                                            StringRef bytes) {
1205   assert(TensorType::isValidElementType(type.getElementType()) &&
1206          "Input element type should be a valid tensor element type");
1207   return Base::get(type.getContext(), type, dialect, bytes);
1208 }
1209 
1210 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1211 
1212 /// Return the value at the given index. If index does not refer to a valid
1213 /// element, then a null attribute is returned.
1214 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1215   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1216   return Attribute();
1217 }
1218 
1219 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1220 
1221 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1222   auto *d = getDialect();
1223   if (!d)
1224     return true;
1225   auto *interface =
1226       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1227   if (!interface)
1228     return true;
1229   return failed(interface->decode(*this, result));
1230 }
1231 
1232 //===----------------------------------------------------------------------===//
1233 // SparseElementsAttr
1234 //===----------------------------------------------------------------------===//
1235 
1236 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1237                                            DenseElementsAttr indices,
1238                                            DenseElementsAttr values) {
1239   assert(indices.getType().getElementType().isInteger(64) &&
1240          "expected sparse indices to be 64-bit integer values");
1241   assert((type.isa<RankedTensorType, VectorType>()) &&
1242          "type must be ranked tensor or vector");
1243   assert(type.hasStaticShape() && "type must have static shape");
1244   return Base::get(type.getContext(), type,
1245                    indices.cast<DenseIntElementsAttr>(), values);
1246 }
1247 
1248 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1249   return getImpl()->indices;
1250 }
1251 
1252 DenseElementsAttr SparseElementsAttr::getValues() const {
1253   return getImpl()->values;
1254 }
1255 
/// Return the value of the element at the given index. Positions with no
/// entry in the sparse index list read as the element type's zero attribute
/// (as produced by getZeroAttr()).
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // A splat index list names a single position whose coordinates all equal
    // the splatted value. If any coordinate of 'index' differs from it, the
    // element is not stored and therefore reads as zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // Each run of 'rank' consecutive index values is one coordinate tuple, and
  // the map keys are ArrayRefs viewing directly into the index storage. The
  // map is rebuilt on every call, so each lookup is linear in the number of
  // stored indices.
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
1296 
1297 /// Get a zero APFloat for the given sparse attribute.
1298 APFloat SparseElementsAttr::getZeroAPFloat() const {
1299   auto eltType = getType().getElementType().cast<FloatType>();
1300   return APFloat(eltType.getFloatSemantics());
1301 }
1302 
1303 /// Get a zero APInt for the given sparse attribute.
1304 APInt SparseElementsAttr::getZeroAPInt() const {
1305   auto eltType = getType().getElementType().cast<IntegerType>();
1306   return APInt::getNullValue(eltType.getWidth());
1307 }
1308 
1309 /// Get a zero attribute for the given attribute type.
1310 Attribute SparseElementsAttr::getZeroAttr() const {
1311   auto eltType = getType().getElementType();
1312 
1313   // Handle floating point elements.
1314   if (eltType.isa<FloatType>())
1315     return FloatAttr::get(eltType, 0);
1316 
1317   // Otherwise, this is an integer.
1318   // TODO: Handle StringAttr here.
1319   return IntegerAttr::get(eltType, 0);
1320 }
1321 
1322 /// Flatten, and return, all of the sparse indices in this attribute in
1323 /// row-major order.
1324 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1325   std::vector<ptrdiff_t> flatSparseIndices;
1326 
1327   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1328   // as a 1-D index array.
1329   auto sparseIndices = getIndices();
1330   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1331   if (sparseIndices.isSplat()) {
1332     SmallVector<uint64_t, 8> indices(getType().getRank(),
1333                                      *sparseIndexValues.begin());
1334     flatSparseIndices.push_back(getFlattenedIndex(indices));
1335     return flatSparseIndices;
1336   }
1337 
1338   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1339   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1340   size_t rank = getType().getRank();
1341   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1342     flatSparseIndices.push_back(getFlattenedIndex(
1343         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1344   return flatSparseIndices;
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // MutableDictionaryAttr
1349 //===----------------------------------------------------------------------===//
1350 
1351 MutableDictionaryAttr::MutableDictionaryAttr(
1352     ArrayRef<NamedAttribute> attributes) {
1353   setAttrs(attributes);
1354 }
1355 
1356 /// Return the underlying dictionary attribute.
1357 DictionaryAttr
1358 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1359   // Construct empty DictionaryAttr if needed.
1360   if (!attrs)
1361     return DictionaryAttr::get({}, context);
1362   return attrs;
1363 }
1364 
1365 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1366   return attrs ? attrs.getValue() : llvm::None;
1367 }
1368 
1369 /// Replace the held attributes with ones provided in 'newAttrs'.
1370 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1371   // Don't create an attribute list if there are no attributes.
1372   if (attributes.empty())
1373     attrs = nullptr;
1374   else
1375     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1376 }
1377 
1378 /// Return the specified attribute if present, null otherwise.
1379 Attribute MutableDictionaryAttr::get(StringRef name) const {
1380   return attrs ? attrs.get(name) : nullptr;
1381 }
1382 
1383 /// Return the specified attribute if present, null otherwise.
1384 Attribute MutableDictionaryAttr::get(Identifier name) const {
1385   return attrs ? attrs.get(name) : nullptr;
1386 }
1387 
1388 /// Return the specified named attribute if present, None otherwise.
1389 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1390   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1391 }
1392 Optional<NamedAttribute>
1393 MutableDictionaryAttr::getNamed(Identifier name) const {
1394   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1395 }
1396 
1397 /// If the an attribute exists with the specified name, change it to the new
1398 /// value.  Otherwise, add a new attribute with the specified name/value.
1399 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1400   assert(value && "attributes may never be null");
1401 
1402   // Look for an existing value for the given name, and set it in-place.
1403   ArrayRef<NamedAttribute> values = getAttrs();
1404   const auto *it = llvm::find_if(
1405       values, [name](NamedAttribute attr) { return attr.first == name; });
1406   if (it != values.end()) {
1407     // Bail out early if the value is the same as what we already have.
1408     if (it->second == value)
1409       return;
1410 
1411     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1412     newAttrs[it - values.begin()].second = value;
1413     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1414     return;
1415   }
1416 
1417   // Otherwise, insert the new attribute into its sorted position.
1418   it = llvm::lower_bound(values, name);
1419   SmallVector<NamedAttribute, 8> newAttrs;
1420   newAttrs.reserve(values.size() + 1);
1421   newAttrs.append(values.begin(), it);
1422   newAttrs.push_back({name, value});
1423   newAttrs.append(it, values.end());
1424   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1425 }
1426 
1427 /// Remove the attribute with the specified name if it exists.  The return
1428 /// value indicates whether the attribute was present or not.
1429 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1430   auto origAttrs = getAttrs();
1431   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1432     if (origAttrs[i].first == name) {
1433       // Handle the simple case of removing the only attribute in the list.
1434       if (e == 1) {
1435         attrs = nullptr;
1436         return RemoveResult::Removed;
1437       }
1438 
1439       SmallVector<NamedAttribute, 8> newAttrs;
1440       newAttrs.reserve(origAttrs.size() - 1);
1441       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1442       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1443       attrs = DictionaryAttr::getWithSorted(newAttrs,
1444                                             newAttrs[0].second.getContext());
1445       return RemoveResult::Removed;
1446     }
1447   }
1448   return RemoveResult::NotFound;
1449 }
1450 
/// Orders named attributes by name using C-string comparison.
/// NOTE(review): strcmp requires both names' data to be null-terminated;
/// presumably Identifier storage guarantees this. Confirm.
bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
  return strcmp(lhs.first.data(), rhs.first.data()) < 0;
}
/// Orders a named attribute's name against a plain string (e.g. for
/// lower_bound over a sorted attribute list). Only the first rhs.size()
/// characters are compared, so 'rhs' need not be null-terminated.
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
  // This is correct even when attr.first.data()[name.size()] is not a zero
  // string terminator, because we only care about a less than comparison.
  // This can't use memcmp, because it doesn't guarantee that it will stop
  // reading both buffers if one is shorter than the other, even if there is
  // a difference.
  return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
}
1462