1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AttributeStorage
26 //===----------------------------------------------------------------------===//
27 
28 AttributeStorage::AttributeStorage(Type type)
29     : type(type.getAsOpaquePointer()) {}
30 AttributeStorage::AttributeStorage() : type(nullptr) {}
31 
32 Type AttributeStorage::getType() const {
33   return Type::getFromOpaquePointer(type);
34 }
35 void AttributeStorage::setType(Type newType) {
36   type = newType.getAsOpaquePointer();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Attribute
41 //===----------------------------------------------------------------------===//
42 
43 /// Return the type of this attribute.
44 Type Attribute::getType() const { return impl->getType(); }
45 
46 /// Return the context this attribute belongs to.
47 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
48 
49 /// Get the dialect this attribute is registered to.
50 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
51 
52 //===----------------------------------------------------------------------===//
53 // AffineMapAttr
54 //===----------------------------------------------------------------------===//
55 
56 AffineMapAttr AffineMapAttr::get(AffineMap value) {
57   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
58 }
59 
60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
61 
62 //===----------------------------------------------------------------------===//
63 // ArrayAttr
64 //===----------------------------------------------------------------------===//
65 
66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
67   return Base::get(context, StandardAttributes::Array, value);
68 }
69 
70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
71 
72 Attribute ArrayAttr::operator[](unsigned idx) const {
73   assert(idx < size() && "index out of bounds");
74   return getValue()[idx];
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // BoolAttr
79 //===----------------------------------------------------------------------===//
80 
81 bool BoolAttr::getValue() const { return getImpl()->value; }
82 
83 //===----------------------------------------------------------------------===//
84 // DictionaryAttr
85 //===----------------------------------------------------------------------===//
86 
87 /// Perform a three-way comparison between the names of the specified
88 /// NamedAttributes.
89 static int compareNamedAttributes(const NamedAttribute *lhs,
90                                   const NamedAttribute *rhs) {
91   return strcmp(lhs->first.data(), rhs->first.data());
92 }
93 
94 /// Returns if the name of the given attribute precedes that of 'name'.
95 static bool compareNamedAttributeWithName(const NamedAttribute &attr,
96                                           StringRef name) {
97   // This is correct even when attr.first.data()[name.size()] is not a zero
98   // string terminator, because we only care about a less than comparison.
99   // This can't use memcmp, because it doesn't guarantee that it will stop
100   // reading both buffers if one is shorter than the other, even if there is
101   // a difference.
102   return strncmp(attr.first.data(), name.data(), name.size()) < 0;
103 }
104 
105 /// Helper function that does either an in place sort or sorts from source array
106 /// into destination. If inPlace then storage is both the source and the
107 /// destination, else value is the source and storage destination. Returns
108 /// whether source was sorted.
109 template <bool inPlace>
110 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
111                                SmallVectorImpl<NamedAttribute> &storage) {
112   // Specialize for the common case.
113   switch (value.size()) {
114   case 0:
115   case 1:
116     // Zero or one elements are already sorted.
117     break;
118   case 2:
119     assert(value[0].first != value[1].first &&
120            "DictionaryAttr element names must be unique");
121     if (compareNamedAttributes(&value[0], &value[1]) > 0) {
122       if (inPlace)
123         std::swap(storage[0], storage[1]);
124       else
125         storage.append({value[1], value[0]});
126       return true;
127     }
128     break;
129   default:
130     // Check to see they are sorted already.
131     bool isSorted =
132         llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) {
133           return compareNamedAttributes(&l, &r) < 0;
134         });
135     if (!isSorted) {
136       // If not, do a general sort.
137       if (!inPlace)
138         storage.append(value.begin(), value.end());
139       llvm::array_pod_sort(storage.begin(), storage.end(),
140                            compareNamedAttributes);
141       value = storage;
142     }
143 
144     // Ensure that the attribute elements are unique.
145     assert(std::adjacent_find(value.begin(), value.end(),
146                               [](NamedAttribute l, NamedAttribute r) {
147                                 return l.first == r.first;
148                               }) == value.end() &&
149            "DictionaryAttr element names must be unique");
150     return !isSorted;
151   }
152   return false;
153 }
154 
155 /// Sorts the NamedAttributes in the array ordered by name as expected by
156 /// getWithSorted.
157 /// Requires: uniquely named attributes.
158 void DictionaryAttr::sort(SmallVectorImpl<NamedAttribute> &array) {
159   dictionaryAttrSort</*inPlace=*/true>(array, array);
160 }
161 
162 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
163                                    MLIRContext *context) {
164   assert(llvm::all_of(value,
165                       [](const NamedAttribute &attr) { return attr.second; }) &&
166          "value cannot have null entries");
167 
168   // We need to sort the element list to canonicalize it.
169   SmallVector<NamedAttribute, 8> storage;
170   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
171     value = storage;
172 
173   return Base::get(context, StandardAttributes::Dictionary, value);
174 }
175 
176 /// Construct a dictionary with an array of values that is known to already be
177 /// sorted by name and uniqued.
178 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
179                                              MLIRContext *context) {
180   // Ensure that the attribute elements are unique and sorted.
181   assert(llvm::is_sorted(value,
182                          [](NamedAttribute l, NamedAttribute r) {
183                            return l.first.strref() < r.first.strref();
184                          }) &&
185          "expected attribute values to be sorted");
186   assert(std::adjacent_find(value.begin(), value.end(),
187                             [](NamedAttribute l, NamedAttribute r) {
188                               return l.first == r.first;
189                             }) == value.end() &&
190          "DictionaryAttr element names must be unique");
191   return Base::get(context, StandardAttributes::Dictionary, value);
192 }
193 
194 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
195   return getImpl()->getElements();
196 }
197 
198 /// Return the specified attribute if present, null otherwise.
199 Attribute DictionaryAttr::get(StringRef name) const {
200   Optional<NamedAttribute> attr = getNamed(name);
201   return attr ? attr->second : nullptr;
202 }
203 Attribute DictionaryAttr::get(Identifier name) const {
204   Optional<NamedAttribute> attr = getNamed(name);
205   return attr ? attr->second : nullptr;
206 }
207 
208 /// Return the specified named attribute if present, None otherwise.
209 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
210   ArrayRef<NamedAttribute> values = getValue();
211   auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
212   return it != values.end() && it->first == name ? *it
213                                                  : Optional<NamedAttribute>();
214 }
215 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
216   for (auto elt : getValue())
217     if (elt.first == name)
218       return elt;
219   return llvm::None;
220 }
221 
222 DictionaryAttr::iterator DictionaryAttr::begin() const {
223   return getValue().begin();
224 }
225 DictionaryAttr::iterator DictionaryAttr::end() const {
226   return getValue().end();
227 }
228 size_t DictionaryAttr::size() const { return getValue().size(); }
229 
230 //===----------------------------------------------------------------------===//
231 // FloatAttr
232 //===----------------------------------------------------------------------===//
233 
234 FloatAttr FloatAttr::get(Type type, double value) {
235   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
236 }
237 
238 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
239   return Base::getChecked(loc, StandardAttributes::Float, type, value);
240 }
241 
242 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
243   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
244 }
245 
246 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
247   return Base::getChecked(loc, StandardAttributes::Float, type, value);
248 }
249 
250 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
251 
252 double FloatAttr::getValueAsDouble() const {
253   return getValueAsDouble(getValue());
254 }
255 double FloatAttr::getValueAsDouble(APFloat value) {
256   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
257     bool losesInfo = false;
258     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
259                   &losesInfo);
260   }
261   return value.convertToDouble();
262 }
263 
264 /// Verify construction invariants.
265 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
266   if (!type.isa<FloatType>())
267     return emitError(loc, "expected floating point type");
268   return success();
269 }
270 
271 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
272                                                       double value) {
273   return verifyFloatTypeInvariants(loc, type);
274 }
275 
276 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
277                                                       const APFloat &value) {
278   // Verify that the type is correct.
279   if (failed(verifyFloatTypeInvariants(loc, type)))
280     return failure();
281 
282   // Verify that the type semantics match that of the value.
283   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
284     return emitError(
285         loc, "FloatAttr type doesn't match the type implied by its value");
286   }
287   return success();
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // SymbolRefAttr
292 //===----------------------------------------------------------------------===//
293 
294 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
295   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
296       .cast<FlatSymbolRefAttr>();
297 }
298 
299 SymbolRefAttr SymbolRefAttr::get(StringRef value,
300                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
301                                  MLIRContext *ctx) {
302   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
303 }
304 
305 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
306 
307 StringRef SymbolRefAttr::getLeafReference() const {
308   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
309   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
310 }
311 
312 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
313   return getImpl()->getNestedRefs();
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // IntegerAttr
318 //===----------------------------------------------------------------------===//
319 
320 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
321   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
322 }
323 
324 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
325   // This uses 64 bit APInts by default for index type.
326   if (type.isIndex())
327     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
328 
329   auto intType = type.cast<IntegerType>();
330   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
331 }
332 
333 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
334 
335 int64_t IntegerAttr::getInt() const {
336   assert((getImpl()->getType().isIndex() ||
337           getImpl()->getType().isSignlessInteger()) &&
338          "must be signless integer");
339   return getValue().getSExtValue();
340 }
341 
342 int64_t IntegerAttr::getSInt() const {
343   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
344   return getValue().getSExtValue();
345 }
346 
347 uint64_t IntegerAttr::getUInt() const {
348   assert(getImpl()->getType().isUnsignedInteger() &&
349          "must be unsigned integer");
350   return getValue().getZExtValue();
351 }
352 
353 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
354   if (type.isa<IntegerType>() || type.isa<IndexType>())
355     return success();
356   return emitError(loc, "expected integer or index type");
357 }
358 
359 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
360                                                         int64_t value) {
361   return verifyIntegerTypeInvariants(loc, type);
362 }
363 
364 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
365                                                         const APInt &value) {
366   if (failed(verifyIntegerTypeInvariants(loc, type)))
367     return failure();
368   if (auto integerType = type.dyn_cast<IntegerType>())
369     if (integerType.getWidth() != value.getBitWidth())
370       return emitError(loc, "integer type bit width (")
371              << integerType.getWidth() << ") doesn't match value bit width ("
372              << value.getBitWidth() << ")";
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // IntegerSetAttr
378 //===----------------------------------------------------------------------===//
379 
380 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
381   return Base::get(value.getConstraint(0).getContext(),
382                    StandardAttributes::IntegerSet, value);
383 }
384 
385 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
386 
387 //===----------------------------------------------------------------------===//
388 // OpaqueAttr
389 //===----------------------------------------------------------------------===//
390 
391 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
392                            MLIRContext *context) {
393   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
394                    type);
395 }
396 
397 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
398                                   Type type, Location location) {
399   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
400                           attrData, type);
401 }
402 
403 /// Returns the dialect namespace of the opaque attribute.
404 Identifier OpaqueAttr::getDialectNamespace() const {
405   return getImpl()->dialectNamespace;
406 }
407 
408 /// Returns the raw attribute data of the opaque attribute.
409 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
410 
411 /// Verify the construction of an opaque attribute.
412 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
413                                                        Identifier dialect,
414                                                        StringRef attrData,
415                                                        Type type) {
416   if (!Dialect::isValidNamespace(dialect.strref()))
417     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
418   return success();
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // StringAttr
423 //===----------------------------------------------------------------------===//
424 
425 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
426   return get(bytes, NoneType::get(context));
427 }
428 
429 /// Get an instance of a StringAttr with the given string and Type.
430 StringAttr StringAttr::get(StringRef bytes, Type type) {
431   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
432 }
433 
434 StringRef StringAttr::getValue() const { return getImpl()->value; }
435 
436 //===----------------------------------------------------------------------===//
437 // TypeAttr
438 //===----------------------------------------------------------------------===//
439 
440 TypeAttr TypeAttr::get(Type value) {
441   return Base::get(value.getContext(), StandardAttributes::Type, value);
442 }
443 
444 Type TypeAttr::getValue() const { return getImpl()->value; }
445 
446 //===----------------------------------------------------------------------===//
447 // ElementsAttr
448 //===----------------------------------------------------------------------===//
449 
450 ShapedType ElementsAttr::getType() const {
451   return Attribute::getType().cast<ShapedType>();
452 }
453 
454 /// Returns the number of elements held by this attribute.
455 int64_t ElementsAttr::getNumElements() const {
456   return getType().getNumElements();
457 }
458 
459 /// Return the value at the given index. If index does not refer to a valid
460 /// element, then a null attribute is returned.
461 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
462   switch (getKind()) {
463   case StandardAttributes::DenseIntOrFPElements:
464     return cast<DenseElementsAttr>().getValue(index);
465   case StandardAttributes::OpaqueElements:
466     return cast<OpaqueElementsAttr>().getValue(index);
467   case StandardAttributes::SparseElements:
468     return cast<SparseElementsAttr>().getValue(index);
469   default:
470     llvm_unreachable("unknown ElementsAttr kind");
471   }
472 }
473 
474 /// Return if the given 'index' refers to a valid element in this attribute.
475 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
476   auto type = getType();
477 
478   // Verify that the rank of the indices matches the held type.
479   auto rank = type.getRank();
480   if (rank != static_cast<int64_t>(index.size()))
481     return false;
482 
483   // Verify that all of the indices are within the shape dimensions.
484   auto shape = type.getShape();
485   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
486     return static_cast<int64_t>(index[i]) < shape[i];
487   });
488 }
489 
490 ElementsAttr
491 ElementsAttr::mapValues(Type newElementType,
492                         function_ref<APInt(const APInt &)> mapping) const {
493   switch (getKind()) {
494   case StandardAttributes::DenseIntOrFPElements:
495     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
496   default:
497     llvm_unreachable("unsupported ElementsAttr subtype");
498   }
499 }
500 
501 ElementsAttr
502 ElementsAttr::mapValues(Type newElementType,
503                         function_ref<APInt(const APFloat &)> mapping) const {
504   switch (getKind()) {
505   case StandardAttributes::DenseIntOrFPElements:
506     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
507   default:
508     llvm_unreachable("unsupported ElementsAttr subtype");
509   }
510 }
511 
512 /// Returns the 1 dimensional flattened row-major index from the given
513 /// multi-dimensional index.
514 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
515   assert(isValidIndex(index) && "expected valid multi-dimensional index");
516   auto type = getType();
517 
518   // Reduce the provided multidimensional index into a flattended 1D row-major
519   // index.
520   auto rank = type.getRank();
521   auto shape = type.getShape();
522   uint64_t valueIndex = 0;
523   uint64_t dimMultiplier = 1;
524   for (int i = rank - 1; i >= 0; --i) {
525     valueIndex += index[i] * dimMultiplier;
526     dimMultiplier *= shape[i];
527   }
528   return valueIndex;
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // DenseElementAttr Utilities
533 //===----------------------------------------------------------------------===//
534 
535 /// Get the bitwidth of a dense element type within the buffer.
536 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
537 static size_t getDenseElementStorageWidth(size_t origWidth) {
538   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
539 }
540 
541 /// Set a bit to a specific value.
542 static void setBit(char *rawData, size_t bitPos, bool value) {
543   if (value)
544     rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
545   else
546     rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
547 }
548 
549 /// Return the value of the specified bit.
550 static bool getBit(const char *rawData, size_t bitPos) {
551   return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
552 }
553 
554 /// Get start position of actual data in `value`. Actual data is
555 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian.
556 static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
557   char *dataPos =
558       const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
559   if (llvm::support::endian::system_endianness() ==
560       llvm::support::endianness::big)
561     dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
562   return dataPos;
563 }
564 
565 /// Read APInt `value` from appropriate position.
566 static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
567   char *dataPos = getAPIntDataPos(value, bitWidth);
568   std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
569 }
570 
571 /// Write `inData` to appropriate position of APInt `value`.
572 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
573   char *dataPos = getAPIntDataPos(value, bitWidth);
574   std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
575 }
576 
577 /// Writes value to the bit position `bitPos` in array `rawData`.
578 static void writeBits(char *rawData, size_t bitPos, APInt value) {
579   size_t bitWidth = value.getBitWidth();
580 
581   // If the bitwidth is 1 we just toggle the specific bit.
582   if (bitWidth == 1)
583     return setBit(rawData, bitPos, value.isOneValue());
584 
585   // Otherwise, the bit position is guaranteed to be byte aligned.
586   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
587   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
588 }
589 
590 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
591 /// `rawData`.
592 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
593   // Handle a boolean bit position.
594   if (bitWidth == 1)
595     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
596 
597   // Otherwise, the bit position must be 8-bit aligned.
598   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
599   APInt result(bitWidth, 0);
600   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
601   return result;
602 }
603 
604 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
605 /// same element count as 'type'.
606 template <typename Values>
607 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
608   return (values.size() == 1) ||
609          (type.getNumElements() == static_cast<int64_t>(values.size()));
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // DenseElementAttr Iterators
614 //===----------------------------------------------------------------------===//
615 
616 /// Constructs a new iterator.
617 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
618     DenseElementsAttr attr, size_t index)
619     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
620                                       Attribute, Attribute, Attribute>(
621           attr.getAsOpaquePointer(), index) {}
622 
623 /// Accesses the Attribute value at this iterator position.
624 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
625   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
626   Type eltTy = owner.getType().getElementType();
627   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
628     if (intEltTy.getWidth() == 1)
629       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
630                            owner.getContext());
631     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
632   }
633   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
634     IntElementIterator intIt(owner, index);
635     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
636     return FloatAttr::get(eltTy, *floatIt);
637   }
638   if (owner.isa<DenseStringElementsAttr>())
639     return StringAttr::get(owner.getRawStringData()[index], eltTy);
640   llvm_unreachable("unexpected element type");
641 }
642 
643 /// Constructs a new iterator.
644 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
645     DenseElementsAttr attr, size_t dataIndex)
646     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
647           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
648 
649 /// Accesses the bool value at this iterator position.
650 bool DenseElementsAttr::BoolElementIterator::operator*() const {
651   return getBit(getData(), getDataIndex());
652 }
653 
654 /// Constructs a new iterator.
655 DenseElementsAttr::IntElementIterator::IntElementIterator(
656     DenseElementsAttr attr, size_t dataIndex)
657     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
658           attr.getRawData().data(), attr.isSplat(), dataIndex),
659       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
660 
661 /// Accesses the raw APInt value at this iterator position.
662 APInt DenseElementsAttr::IntElementIterator::operator*() const {
663   return readBits(getData(),
664                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
665                   bitWidth);
666 }
667 
668 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
669     const llvm::fltSemantics &smt, IntElementIterator it)
670     : llvm::mapped_iterator<IntElementIterator,
671                             std::function<APFloat(const APInt &)>>(
672           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
673 
674 //===----------------------------------------------------------------------===//
675 // DenseElementsAttr
676 //===----------------------------------------------------------------------===//
677 
678 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
679                                          ArrayRef<Attribute> values) {
680   assert(hasSameElementsOrSplat(type, values));
681 
682   // If the element type is not based on int/float/index, assume it is a string
683   // type.
684   auto eltType = type.getElementType();
685   if (!type.getElementType().isIntOrIndexOrFloat()) {
686     SmallVector<StringRef, 8> stringValues;
687     stringValues.reserve(values.size());
688     for (Attribute attr : values) {
689       assert(attr.isa<StringAttr>() &&
690              "expected string value for non integer/index/float element");
691       stringValues.push_back(attr.cast<StringAttr>().getValue());
692     }
693     return get(type, stringValues);
694   }
695 
696   // Otherwise, get the raw storage width to use for the allocation.
697   size_t bitWidth = getDenseElementBitWidth(eltType);
698   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
699 
700   // Compress the attribute values into a character buffer.
701   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
702                             values.size());
703   APInt intVal;
704   for (unsigned i = 0, e = values.size(); i < e; ++i) {
705     assert(eltType == values[i].getType() &&
706            "expected attribute value to have element type");
707 
708     switch (eltType.getKind()) {
709     case StandardTypes::BF16:
710     case StandardTypes::F16:
711     case StandardTypes::F32:
712     case StandardTypes::F64:
713       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
714       break;
715     case StandardTypes::Integer:
716     case StandardTypes::Index:
717       intVal = values[i].isa<BoolAttr>()
718                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
719                    : values[i].cast<IntegerAttr>().getValue();
720       break;
721     default:
722       llvm_unreachable("unexpected element type");
723     }
724     assert(intVal.getBitWidth() == bitWidth &&
725            "expected value to have same bitwidth as element type");
726     writeBits(data.data(), i * storageBitWidth, intVal);
727   }
728   return DenseIntOrFPElementsAttr::getRaw(type, data,
729                                           /*isSplat=*/(values.size() == 1));
730 }
731 
732 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
733                                          ArrayRef<bool> values) {
734   assert(hasSameElementsOrSplat(type, values));
735   assert(type.getElementType().isInteger(1));
736 
737   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
738   for (int i = 0, e = values.size(); i != e; ++i)
739     setBit(buff.data(), i, values[i]);
740   return DenseIntOrFPElementsAttr::getRaw(type, buff,
741                                           /*isSplat=*/(values.size() == 1));
742 }
743 
744 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
745                                          ArrayRef<StringRef> values) {
746   assert(!type.getElementType().isIntOrFloat());
747   return DenseStringElementsAttr::get(type, values);
748 }
749 
750 /// Constructs a dense integer elements attribute from an array of APInt
751 /// values. Each APInt value is expected to have the same bitwidth as the
752 /// element type of 'type'.
753 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
754                                          ArrayRef<APInt> values) {
755   assert(type.getElementType().isIntOrIndex());
756   return DenseIntOrFPElementsAttr::getRaw(type, values);
757 }
758 
759 // Constructs a dense float elements attribute from an array of APFloat
760 // values. Each APFloat value is expected to have the same bitwidth as the
761 // element type of 'type'.
762 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
763                                          ArrayRef<APFloat> values) {
764   assert(type.getElementType().isa<FloatType>());
765 
766   // Convert the APFloat values to APInt and create a dense elements attribute.
767   std::vector<APInt> intValues(values.size());
768   for (unsigned i = 0, e = values.size(); i != e; ++i)
769     intValues[i] = values[i].bitcastToAPInt();
770   return DenseIntOrFPElementsAttr::getRaw(type, intValues);
771 }
772 
773 /// Construct a dense elements attribute from a raw buffer representing the
774 /// data for this attribute. Users should generally not use this methods as
775 /// the expected buffer format may not be a form the user expects.
776 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
777                                                       ArrayRef<char> rawBuffer,
778                                                       bool isSplatBuffer) {
779   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
780 }
781 
782 /// Returns true if the given buffer is a valid raw buffer for the given type.
783 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
784                                          ArrayRef<char> rawBuffer,
785                                          bool &detectedSplat) {
786   size_t elementWidth = getDenseElementBitWidth(type.getElementType());
787   size_t storageWidth = getDenseElementStorageWidth(elementWidth);
788   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
789 
790   // Storage width of 1 is special as it is packed by the bit.
791   if (storageWidth == 1) {
792     // Check for a splat, or a buffer equal to the number of elements.
793     if ((detectedSplat = rawBuffer.size() == 1))
794       return true;
795     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
796   }
797   // All other types are 8-bit aligned.
798   if ((detectedSplat = rawBufferWidth == storageWidth))
799     return true;
800   return rawBufferWidth == (storageWidth * type.getNumElements());
801 }
802 
803 /// Check the information for a C++ data type, check if this type is valid for
804 /// the current attribute. This method is used to verify specific type
805 /// invariants that the templatized 'getValues' method cannot.
806 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
807                               bool isSigned) {
808   // Make sure that the data element size is the same as the type element width.
809   if (getDenseElementBitWidth(type) !=
810       static_cast<size_t>(dataEltSize * CHAR_BIT))
811     return false;
812 
813   // Check that the element type is either float or integer or index.
814   if (!isInt)
815     return type.isa<FloatType>();
816   if (type.isIndex())
817     return true;
818 
819   auto intType = type.dyn_cast<IntegerType>();
820   if (!intType)
821     return false;
822 
823   // Make sure signedness semantics is consistent.
824   if (intType.isSignless())
825     return true;
826   return intType.isSigned() ? isSigned : !isSigned;
827 }
828 
829 /// Defaults down the subclass implementation.
830 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
831                                                    ArrayRef<char> data,
832                                                    int64_t dataEltSize,
833                                                    bool isInt, bool isSigned) {
834   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
835                                                  isSigned);
836 }
837 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
838                                                       ArrayRef<char> data,
839                                                       int64_t dataEltSize,
840                                                       bool isInt,
841                                                       bool isSigned) {
842   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
843                                                     isInt, isSigned);
844 }
845 
846 /// A method used to verify specific type invariants that the templatized 'get'
847 /// method cannot.
848 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
849                                           bool isSigned) const {
850   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
851                              isSigned);
852 }
853 
854 /// Check the information for a C++ data type, check if this type is valid for
855 /// the current attribute.
856 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
857                                        bool isSigned) const {
858   return ::isValidIntOrFloat(
859       getType().getElementType().cast<ComplexType>().getElementType(),
860       dataEltSize / 2, isInt, isSigned);
861 }
862 
863 /// Returns if this attribute corresponds to a splat, i.e. if all element
864 /// values are the same.
865 bool DenseElementsAttr::isSplat() const {
866   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
867 }
868 
869 /// Return the held element values as a range of Attributes.
870 auto DenseElementsAttr::getAttributeValues() const
871     -> llvm::iterator_range<AttributeElementIterator> {
872   return {attr_value_begin(), attr_value_end()};
873 }
874 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
875   return AttributeElementIterator(*this, 0);
876 }
877 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
878   return AttributeElementIterator(*this, getNumElements());
879 }
880 
881 /// Return the held element values as a range of bool. The element type of
882 /// this attribute must be of integer type of bitwidth 1.
883 auto DenseElementsAttr::getBoolValues() const
884     -> llvm::iterator_range<BoolElementIterator> {
885   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
886   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
887   (void)eltType;
888   return {BoolElementIterator(*this, 0),
889           BoolElementIterator(*this, getNumElements())};
890 }
891 
892 /// Return the held element values as a range of APInts. The element type of
893 /// this attribute must be of integer type.
894 auto DenseElementsAttr::getIntValues() const
895     -> llvm::iterator_range<IntElementIterator> {
896   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
897   return {raw_int_begin(), raw_int_end()};
898 }
899 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
900   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
901   return raw_int_begin();
902 }
903 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
904   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
905   return raw_int_end();
906 }
907 
908 /// Return the held element values as a range of APFloat. The element type of
909 /// this attribute must be of float type.
910 auto DenseElementsAttr::getFloatValues() const
911     -> llvm::iterator_range<FloatElementIterator> {
912   auto elementType = getType().getElementType().cast<FloatType>();
913   assert(elementType.isa<FloatType>() && "expected float type");
914   const auto &elementSemantics = elementType.getFloatSemantics();
915   return {FloatElementIterator(elementSemantics, raw_int_begin()),
916           FloatElementIterator(elementSemantics, raw_int_end())};
917 }
918 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
919   return getFloatValues().begin();
920 }
921 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
922   return getFloatValues().end();
923 }
924 
925 /// Return the raw storage data held by this attribute.
926 ArrayRef<char> DenseElementsAttr::getRawData() const {
927   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
928 }
929 
930 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
931   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
932 }
933 
934 /// Return a new DenseElementsAttr that has the same data as the current
935 /// attribute, but has been reshaped to 'newType'. The new type must have the
936 /// same total number of elements as well as element type.
937 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
938   ShapedType curType = getType();
939   if (curType == newType)
940     return *this;
941 
942   (void)curType;
943   assert(newType.getElementType() == curType.getElementType() &&
944          "expected the same element type");
945   assert(newType.getNumElements() == curType.getNumElements() &&
946          "expected the same number of elements");
947   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
948 }
949 
950 DenseElementsAttr
951 DenseElementsAttr::mapValues(Type newElementType,
952                              function_ref<APInt(const APInt &)> mapping) const {
953   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
954 }
955 
956 DenseElementsAttr DenseElementsAttr::mapValues(
957     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
958   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // DenseStringElementsAttr
963 //===----------------------------------------------------------------------===//
964 
965 DenseStringElementsAttr
966 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
967   return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
968                    type, values, (values.size() == 1));
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // DenseIntOrFPElementsAttr
973 //===----------------------------------------------------------------------===//
974 
975 /// Constructs a dense elements attribute from an array of raw APInt values.
976 /// Each APInt value is expected to have the same bitwidth as the element type
977 /// of 'type'.
978 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
979                                                    ArrayRef<APInt> values) {
980   assert(hasSameElementsOrSplat(type, values));
981 
982   size_t bitWidth = getDenseElementBitWidth(type.getElementType());
983   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
984   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
985                                 values.size());
986   for (unsigned i = 0, e = values.size(); i != e; ++i) {
987     assert(values[i].getBitWidth() == bitWidth);
988     writeBits(elementData.data(), i * storageBitWidth, values[i]);
989   }
990   return DenseIntOrFPElementsAttr::getRaw(type, elementData,
991                                           /*isSplat=*/(values.size() == 1));
992 }
993 
994 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
995                                                    ArrayRef<char> data,
996                                                    bool isSplat) {
997   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
998          "type must be ranked tensor or vector");
999   assert(type.hasStaticShape() && "type must have static shape");
1000   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
1001                    type, data, isSplat);
1002 }
1003 
1004 /// Overload of the raw 'get' method that asserts that the given type is of
1005 /// complex type. This method is used to verify type invariants that the
1006 /// templatized 'get' method cannot.
1007 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1008                                                           ArrayRef<char> data,
1009                                                           int64_t dataEltSize,
1010                                                           bool isInt,
1011                                                           bool isSigned) {
1012   assert(::isValidIntOrFloat(
1013       type.getElementType().cast<ComplexType>().getElementType(),
1014       dataEltSize / 2, isInt, isSigned));
1015 
1016   int64_t numElements = data.size() / dataEltSize;
1017   assert(numElements == 1 || numElements == type.getNumElements());
1018   return getRaw(type, data, /*isSplat=*/numElements == 1);
1019 }
1020 
1021 /// Overload of the 'getRaw' method that asserts that the given type is of
1022 /// integer type. This method is used to verify type invariants that the
1023 /// templatized 'get' method cannot.
1024 DenseElementsAttr
1025 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1026                                            int64_t dataEltSize, bool isInt,
1027                                            bool isSigned) {
1028   assert(
1029       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1030 
1031   int64_t numElements = data.size() / dataEltSize;
1032   assert(numElements == 1 || numElements == type.getNumElements());
1033   return getRaw(type, data, /*isSplat=*/numElements == 1);
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // DenseFPElementsAttr
1038 //===----------------------------------------------------------------------===//
1039 
1040 template <typename Fn, typename Attr>
1041 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1042                                 Type newElementType,
1043                                 llvm::SmallVectorImpl<char> &data) {
1044   size_t bitWidth = getDenseElementBitWidth(newElementType);
1045   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1046 
1047   ShapedType newArrayType;
1048   if (inType.isa<RankedTensorType>())
1049     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1050   else if (inType.isa<UnrankedTensorType>())
1051     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1052   else if (inType.isa<VectorType>())
1053     newArrayType = VectorType::get(inType.getShape(), newElementType);
1054   else
1055     assert(newArrayType && "Unhandled tensor type");
1056 
1057   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1058   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1059 
1060   // Functor used to process a single element value of the attribute.
1061   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1062     auto newInt = mapping(value);
1063     assert(newInt.getBitWidth() == bitWidth);
1064     writeBits(data.data(), index * storageBitWidth, newInt);
1065   };
1066 
1067   // Check for the splat case.
1068   if (attr.isSplat()) {
1069     processElt(*attr.begin(), /*index=*/0);
1070     return newArrayType;
1071   }
1072 
1073   // Otherwise, process all of the element values.
1074   uint64_t elementIdx = 0;
1075   for (auto value : attr)
1076     processElt(value, elementIdx++);
1077   return newArrayType;
1078 }
1079 
1080 DenseElementsAttr DenseFPElementsAttr::mapValues(
1081     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1082   llvm::SmallVector<char, 8> elementData;
1083   auto newArrayType =
1084       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1085 
1086   return getRaw(newArrayType, elementData, isSplat());
1087 }
1088 
1089 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1090 bool DenseFPElementsAttr::classof(Attribute attr) {
1091   return attr.isa<DenseElementsAttr>() &&
1092          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1093 }
1094 
1095 //===----------------------------------------------------------------------===//
1096 // DenseIntElementsAttr
1097 //===----------------------------------------------------------------------===//
1098 
1099 DenseElementsAttr DenseIntElementsAttr::mapValues(
1100     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1101   llvm::SmallVector<char, 8> elementData;
1102   auto newArrayType =
1103       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1104 
1105   return getRaw(newArrayType, elementData, isSplat());
1106 }
1107 
1108 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1109 bool DenseIntElementsAttr::classof(Attribute attr) {
1110   return attr.isa<DenseElementsAttr>() &&
1111          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1112 }
1113 
1114 //===----------------------------------------------------------------------===//
1115 // OpaqueElementsAttr
1116 //===----------------------------------------------------------------------===//
1117 
1118 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1119                                            StringRef bytes) {
1120   assert(TensorType::isValidElementType(type.getElementType()) &&
1121          "Input element type should be a valid tensor element type");
1122   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
1123                    dialect, bytes);
1124 }
1125 
1126 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1127 
1128 /// Return the value at the given index. If index does not refer to a valid
1129 /// element, then a null attribute is returned.
1130 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1131   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1132   if (Dialect *dialect = getDialect())
1133     return dialect->extractElementHook(*this, index);
1134   return Attribute();
1135 }
1136 
1137 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1138 
1139 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1140   if (auto *d = getDialect())
1141     return d->decodeHook(*this, result);
1142   return true;
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // SparseElementsAttr
1147 //===----------------------------------------------------------------------===//
1148 
1149 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1150                                            DenseElementsAttr indices,
1151                                            DenseElementsAttr values) {
1152   assert(indices.getType().getElementType().isInteger(64) &&
1153          "expected sparse indices to be 64-bit integer values");
1154   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1155          "type must be ranked tensor or vector");
1156   assert(type.hasStaticShape() && "type must have static shape");
1157   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
1158                    indices.cast<DenseIntElementsAttr>(), values);
1159 }
1160 
1161 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1162   return getImpl()->indices;
1163 }
1164 
1165 DenseElementsAttr SparseElementsAttr::getValues() const {
1166   return getImpl()->values;
1167 }
1168 
1169 /// Return the value of the element at the given index.
1170 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1171   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1172   auto type = getType();
1173 
1174   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1175   // as a 1-D index array.
1176   auto sparseIndices = getIndices();
1177   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1178 
1179   // Check to see if the indices are a splat.
1180   if (sparseIndices.isSplat()) {
1181     // If the index is also not a splat of the index value, we know that the
1182     // value is zero.
1183     auto splatIndex = *sparseIndexValues.begin();
1184     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1185       return getZeroAttr();
1186 
1187     // If the indices are a splat, we also expect the values to be a splat.
1188     assert(getValues().isSplat() && "expected splat values");
1189     return getValues().getSplatValue();
1190   }
1191 
1192   // Build a mapping between known indices and the offset of the stored element.
1193   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1194   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1195   size_t rank = type.getRank();
1196   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1197     mappedIndices.try_emplace(
1198         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1199 
1200   // Look for the provided index key within the mapped indices. If the provided
1201   // index is not found, then return a zero attribute.
1202   auto it = mappedIndices.find(index);
1203   if (it == mappedIndices.end())
1204     return getZeroAttr();
1205 
1206   // Otherwise, return the held sparse value element.
1207   return getValues().getValue(it->second);
1208 }
1209 
1210 /// Get a zero APFloat for the given sparse attribute.
1211 APFloat SparseElementsAttr::getZeroAPFloat() const {
1212   auto eltType = getType().getElementType().cast<FloatType>();
1213   return APFloat(eltType.getFloatSemantics());
1214 }
1215 
1216 /// Get a zero APInt for the given sparse attribute.
1217 APInt SparseElementsAttr::getZeroAPInt() const {
1218   auto eltType = getType().getElementType().cast<IntegerType>();
1219   return APInt::getNullValue(eltType.getWidth());
1220 }
1221 
1222 /// Get a zero attribute for the given attribute type.
1223 Attribute SparseElementsAttr::getZeroAttr() const {
1224   auto eltType = getType().getElementType();
1225 
1226   // Handle floating point elements.
1227   if (eltType.isa<FloatType>())
1228     return FloatAttr::get(eltType, 0);
1229 
1230   // Otherwise, this is an integer.
1231   auto intEltTy = eltType.cast<IntegerType>();
1232   if (intEltTy.getWidth() == 1)
1233     return BoolAttr::get(false, eltType.getContext());
1234   return IntegerAttr::get(eltType, 0);
1235 }
1236 
1237 /// Flatten, and return, all of the sparse indices in this attribute in
1238 /// row-major order.
1239 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1240   std::vector<ptrdiff_t> flatSparseIndices;
1241 
1242   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1243   // as a 1-D index array.
1244   auto sparseIndices = getIndices();
1245   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1246   if (sparseIndices.isSplat()) {
1247     SmallVector<uint64_t, 8> indices(getType().getRank(),
1248                                      *sparseIndexValues.begin());
1249     flatSparseIndices.push_back(getFlattenedIndex(indices));
1250     return flatSparseIndices;
1251   }
1252 
1253   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1254   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1255   size_t rank = getType().getRank();
1256   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1257     flatSparseIndices.push_back(getFlattenedIndex(
1258         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1259   return flatSparseIndices;
1260 }
1261 
1262 //===----------------------------------------------------------------------===//
1263 // MutableDictionaryAttr
1264 //===----------------------------------------------------------------------===//
1265 
1266 MutableDictionaryAttr::MutableDictionaryAttr(
1267     ArrayRef<NamedAttribute> attributes) {
1268   setAttrs(attributes);
1269 }
1270 
1271 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1272   return attrs ? attrs.getValue() : llvm::None;
1273 }
1274 
1275 /// Replace the held attributes with ones provided in 'newAttrs'.
1276 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1277   // Don't create an attribute list if there are no attributes.
1278   if (attributes.empty())
1279     attrs = nullptr;
1280   else
1281     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1282 }
1283 
1284 /// Return the specified attribute if present, null otherwise.
1285 Attribute MutableDictionaryAttr::get(StringRef name) const {
1286   return attrs ? attrs.get(name) : nullptr;
1287 }
1288 
1289 /// Return the specified attribute if present, null otherwise.
1290 Attribute MutableDictionaryAttr::get(Identifier name) const {
1291   return attrs ? attrs.get(name) : nullptr;
1292 }
1293 
1294 /// Return the specified named attribute if present, None otherwise.
1295 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1296   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1297 }
1298 Optional<NamedAttribute>
1299 MutableDictionaryAttr::getNamed(Identifier name) const {
1300   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1301 }
1302 
1303 /// If the an attribute exists with the specified name, change it to the new
1304 /// value.  Otherwise, add a new attribute with the specified name/value.
1305 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1306   assert(value && "attributes may never be null");
1307 
1308   // Look for an existing value for the given name, and set it in-place.
1309   ArrayRef<NamedAttribute> values = getAttrs();
1310   auto it = llvm::find_if(
1311       values, [name](NamedAttribute attr) { return attr.first == name; });
1312   if (it != values.end()) {
1313     // Bail out early if the value is the same as what we already have.
1314     if (it->second == value)
1315       return;
1316 
1317     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1318     newAttrs[it - values.begin()].second = value;
1319     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1320     return;
1321   }
1322 
1323   // Otherwise, insert the new attribute into its sorted position.
1324   it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
1325   SmallVector<NamedAttribute, 8> newAttrs;
1326   newAttrs.reserve(values.size() + 1);
1327   newAttrs.append(values.begin(), it);
1328   newAttrs.push_back({name, value});
1329   newAttrs.append(it, values.end());
1330   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1331 }
1332 
1333 /// Remove the attribute with the specified name if it exists.  The return
1334 /// value indicates whether the attribute was present or not.
1335 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1336   auto origAttrs = getAttrs();
1337   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1338     if (origAttrs[i].first == name) {
1339       // Handle the simple case of removing the only attribute in the list.
1340       if (e == 1) {
1341         attrs = nullptr;
1342         return RemoveResult::Removed;
1343       }
1344 
1345       SmallVector<NamedAttribute, 8> newAttrs;
1346       newAttrs.reserve(origAttrs.size() - 1);
1347       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1348       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1349       attrs = DictionaryAttr::getWithSorted(newAttrs,
1350                                             newAttrs[0].second.getContext());
1351       return RemoveResult::Removed;
1352     }
1353   }
1354   return RemoveResult::NotFound;
1355 }
1356