1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AttributeStorage
26 //===----------------------------------------------------------------------===//
27 
28 AttributeStorage::AttributeStorage(Type type)
29     : type(type.getAsOpaquePointer()) {}
30 AttributeStorage::AttributeStorage() : type(nullptr) {}
31 
32 Type AttributeStorage::getType() const {
33   return Type::getFromOpaquePointer(type);
34 }
35 void AttributeStorage::setType(Type newType) {
36   type = newType.getAsOpaquePointer();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Attribute
41 //===----------------------------------------------------------------------===//
42 
43 /// Return the type of this attribute.
44 Type Attribute::getType() const { return impl->getType(); }
45 
46 /// Return the context this attribute belongs to.
47 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
48 
49 /// Get the dialect this attribute is registered to.
50 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
51 
52 //===----------------------------------------------------------------------===//
53 // AffineMapAttr
54 //===----------------------------------------------------------------------===//
55 
56 AffineMapAttr AffineMapAttr::get(AffineMap value) {
57   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
58 }
59 
60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
61 
62 //===----------------------------------------------------------------------===//
63 // ArrayAttr
64 //===----------------------------------------------------------------------===//
65 
66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
67   return Base::get(context, StandardAttributes::Array, value);
68 }
69 
70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
71 
72 Attribute ArrayAttr::operator[](unsigned idx) const {
73   assert(idx < size() && "index out of bounds");
74   return getValue()[idx];
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // DictionaryAttr
79 //===----------------------------------------------------------------------===//
80 
81 /// Helper function that does either an in place sort or sorts from source array
82 /// into destination. If inPlace then storage is both the source and the
83 /// destination, else value is the source and storage destination. Returns
84 /// whether source was sorted.
85 template <bool inPlace>
86 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
87                                SmallVectorImpl<NamedAttribute> &storage) {
88   // Specialize for the common case.
89   switch (value.size()) {
90   case 0:
91     // Zero already sorted.
92     break;
93   case 1:
94     // One already sorted but may need to be copied.
95     if (!inPlace)
96       storage.assign({value[0]});
97     break;
98   case 2: {
99     assert(value[0].first != value[1].first &&
100            "DictionaryAttr element names must be unique");
101     bool isSorted = value[0] < value[1];
102     if (inPlace) {
103       if (!isSorted)
104         std::swap(storage[0], storage[1]);
105     } else if (isSorted) {
106       storage.assign({value[0], value[1]});
107     } else {
108       storage.assign({value[1], value[0]});
109     }
110     return !isSorted;
111   }
112   default:
113     if (!inPlace)
114       storage.assign(value.begin(), value.end());
115     // Check to see they are sorted already.
116     bool isSorted = llvm::is_sorted(value);
117     if (!isSorted) {
118       // If not, do a general sort.
119       llvm::array_pod_sort(storage.begin(), storage.end());
120       value = storage;
121     }
122 
123     // Ensure that the attribute elements are unique.
124     assert(std::adjacent_find(value.begin(), value.end(),
125                               [](NamedAttribute l, NamedAttribute r) {
126                                 return l.first == r.first;
127                               }) == value.end() &&
128            "DictionaryAttr element names must be unique");
129     return !isSorted;
130   }
131   return false;
132 }
133 
134 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
135                           SmallVectorImpl<NamedAttribute> &storage) {
136   return dictionaryAttrSort</*inPlace=*/false>(value, storage);
137 }
138 
139 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
140   return dictionaryAttrSort</*inPlace=*/true>(array, array);
141 }
142 
143 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
144                                    MLIRContext *context) {
145   if (value.empty())
146     return DictionaryAttr::getEmpty(context);
147   assert(llvm::all_of(value,
148                       [](const NamedAttribute &attr) { return attr.second; }) &&
149          "value cannot have null entries");
150 
151   // We need to sort the element list to canonicalize it.
152   SmallVector<NamedAttribute, 8> storage;
153   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
154     value = storage;
155 
156   return Base::get(context, StandardAttributes::Dictionary, value);
157 }
158 /// Construct a dictionary with an array of values that is known to already be
159 /// sorted by name and uniqued.
160 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
161                                              MLIRContext *context) {
162   if (value.empty())
163     return DictionaryAttr::getEmpty(context);
164   // Ensure that the attribute elements are unique and sorted.
165   assert(llvm::is_sorted(value,
166                          [](NamedAttribute l, NamedAttribute r) {
167                            return l.first.strref() < r.first.strref();
168                          }) &&
169          "expected attribute values to be sorted");
170   assert(std::adjacent_find(value.begin(), value.end(),
171                             [](NamedAttribute l, NamedAttribute r) {
172                               return l.first == r.first;
173                             }) == value.end() &&
174          "DictionaryAttr element names must be unique");
175   return Base::get(context, StandardAttributes::Dictionary, value);
176 }
177 
178 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
179   return getImpl()->getElements();
180 }
181 
182 /// Return the specified attribute if present, null otherwise.
183 Attribute DictionaryAttr::get(StringRef name) const {
184   Optional<NamedAttribute> attr = getNamed(name);
185   return attr ? attr->second : nullptr;
186 }
187 Attribute DictionaryAttr::get(Identifier name) const {
188   Optional<NamedAttribute> attr = getNamed(name);
189   return attr ? attr->second : nullptr;
190 }
191 
192 /// Return the specified named attribute if present, None otherwise.
193 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
194   ArrayRef<NamedAttribute> values = getValue();
195   const auto *it = llvm::lower_bound(values, name);
196   return it != values.end() && it->first == name ? *it
197                                                  : Optional<NamedAttribute>();
198 }
199 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
200   for (auto elt : getValue())
201     if (elt.first == name)
202       return elt;
203   return llvm::None;
204 }
205 
206 DictionaryAttr::iterator DictionaryAttr::begin() const {
207   return getValue().begin();
208 }
209 DictionaryAttr::iterator DictionaryAttr::end() const {
210   return getValue().end();
211 }
212 size_t DictionaryAttr::size() const { return getValue().size(); }
213 
214 //===----------------------------------------------------------------------===//
215 // FloatAttr
216 //===----------------------------------------------------------------------===//
217 
218 FloatAttr FloatAttr::get(Type type, double value) {
219   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
220 }
221 
222 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
223   return Base::getChecked(loc, StandardAttributes::Float, type, value);
224 }
225 
226 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
227   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
228 }
229 
230 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
231   return Base::getChecked(loc, StandardAttributes::Float, type, value);
232 }
233 
234 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
235 
236 double FloatAttr::getValueAsDouble() const {
237   return getValueAsDouble(getValue());
238 }
239 double FloatAttr::getValueAsDouble(APFloat value) {
240   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
241     bool losesInfo = false;
242     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
243                   &losesInfo);
244   }
245   return value.convertToDouble();
246 }
247 
248 /// Verify construction invariants.
249 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
250   if (!type.isa<FloatType>())
251     return emitError(loc, "expected floating point type");
252   return success();
253 }
254 
255 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
256                                                       double value) {
257   return verifyFloatTypeInvariants(loc, type);
258 }
259 
260 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
261                                                       const APFloat &value) {
262   // Verify that the type is correct.
263   if (failed(verifyFloatTypeInvariants(loc, type)))
264     return failure();
265 
266   // Verify that the type semantics match that of the value.
267   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
268     return emitError(
269         loc, "FloatAttr type doesn't match the type implied by its value");
270   }
271   return success();
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // SymbolRefAttr
276 //===----------------------------------------------------------------------===//
277 
278 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
279   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
280       .cast<FlatSymbolRefAttr>();
281 }
282 
283 SymbolRefAttr SymbolRefAttr::get(StringRef value,
284                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
285                                  MLIRContext *ctx) {
286   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
287 }
288 
289 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
290 
291 StringRef SymbolRefAttr::getLeafReference() const {
292   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
293   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
294 }
295 
296 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
297   return getImpl()->getNestedRefs();
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // IntegerAttr
302 //===----------------------------------------------------------------------===//
303 
304 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
305   if (type.isSignlessInteger(1))
306     return BoolAttr::get(value.getBoolValue(), type.getContext());
307   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
308 }
309 
310 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
311   // This uses 64 bit APInts by default for index type.
312   if (type.isIndex())
313     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
314 
315   auto intType = type.cast<IntegerType>();
316   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
317 }
318 
319 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
320 
321 int64_t IntegerAttr::getInt() const {
322   assert((getImpl()->getType().isIndex() ||
323           getImpl()->getType().isSignlessInteger()) &&
324          "must be signless integer");
325   return getValue().getSExtValue();
326 }
327 
328 int64_t IntegerAttr::getSInt() const {
329   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
330   return getValue().getSExtValue();
331 }
332 
333 uint64_t IntegerAttr::getUInt() const {
334   assert(getImpl()->getType().isUnsignedInteger() &&
335          "must be unsigned integer");
336   return getValue().getZExtValue();
337 }
338 
339 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
340   if (type.isa<IntegerType>() || type.isa<IndexType>())
341     return success();
342   return emitError(loc, "expected integer or index type");
343 }
344 
345 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
346                                                         int64_t value) {
347   return verifyIntegerTypeInvariants(loc, type);
348 }
349 
350 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
351                                                         const APInt &value) {
352   if (failed(verifyIntegerTypeInvariants(loc, type)))
353     return failure();
354   if (auto integerType = type.dyn_cast<IntegerType>())
355     if (integerType.getWidth() != value.getBitWidth())
356       return emitError(loc, "integer type bit width (")
357              << integerType.getWidth() << ") doesn't match value bit width ("
358              << value.getBitWidth() << ")";
359   return success();
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // BoolAttr
364 
365 bool BoolAttr::getValue() const {
366   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
367   return storage->getValue().getBoolValue();
368 }
369 
370 bool BoolAttr::classof(Attribute attr) {
371   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
372   return intAttr && intAttr.getType().isSignlessInteger(1);
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // IntegerSetAttr
377 //===----------------------------------------------------------------------===//
378 
379 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
380   return Base::get(value.getConstraint(0).getContext(),
381                    StandardAttributes::IntegerSet, value);
382 }
383 
384 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
385 
386 //===----------------------------------------------------------------------===//
387 // OpaqueAttr
388 //===----------------------------------------------------------------------===//
389 
390 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
391                            MLIRContext *context) {
392   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
393                    type);
394 }
395 
396 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
397                                   Type type, Location location) {
398   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
399                           attrData, type);
400 }
401 
402 /// Returns the dialect namespace of the opaque attribute.
403 Identifier OpaqueAttr::getDialectNamespace() const {
404   return getImpl()->dialectNamespace;
405 }
406 
407 /// Returns the raw attribute data of the opaque attribute.
408 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
409 
410 /// Verify the construction of an opaque attribute.
411 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
412                                                        Identifier dialect,
413                                                        StringRef attrData,
414                                                        Type type) {
415   if (!Dialect::isValidNamespace(dialect.strref()))
416     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
417   return success();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // StringAttr
422 //===----------------------------------------------------------------------===//
423 
424 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
425   return get(bytes, NoneType::get(context));
426 }
427 
428 /// Get an instance of a StringAttr with the given string and Type.
429 StringAttr StringAttr::get(StringRef bytes, Type type) {
430   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
431 }
432 
433 StringRef StringAttr::getValue() const { return getImpl()->value; }
434 
435 //===----------------------------------------------------------------------===//
436 // TypeAttr
437 //===----------------------------------------------------------------------===//
438 
439 TypeAttr TypeAttr::get(Type value) {
440   return Base::get(value.getContext(), StandardAttributes::Type, value);
441 }
442 
443 Type TypeAttr::getValue() const { return getImpl()->value; }
444 
445 //===----------------------------------------------------------------------===//
446 // ElementsAttr
447 //===----------------------------------------------------------------------===//
448 
449 ShapedType ElementsAttr::getType() const {
450   return Attribute::getType().cast<ShapedType>();
451 }
452 
453 /// Returns the number of elements held by this attribute.
454 int64_t ElementsAttr::getNumElements() const {
455   return getType().getNumElements();
456 }
457 
458 /// Return the value at the given index. If index does not refer to a valid
459 /// element, then a null attribute is returned.
460 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
461   switch (getKind()) {
462   case StandardAttributes::DenseIntOrFPElements:
463     return cast<DenseElementsAttr>().getValue(index);
464   case StandardAttributes::OpaqueElements:
465     return cast<OpaqueElementsAttr>().getValue(index);
466   case StandardAttributes::SparseElements:
467     return cast<SparseElementsAttr>().getValue(index);
468   default:
469     llvm_unreachable("unknown ElementsAttr kind");
470   }
471 }
472 
473 /// Return if the given 'index' refers to a valid element in this attribute.
474 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
475   auto type = getType();
476 
477   // Verify that the rank of the indices matches the held type.
478   auto rank = type.getRank();
479   if (rank != static_cast<int64_t>(index.size()))
480     return false;
481 
482   // Verify that all of the indices are within the shape dimensions.
483   auto shape = type.getShape();
484   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
485     return static_cast<int64_t>(index[i]) < shape[i];
486   });
487 }
488 
489 ElementsAttr
490 ElementsAttr::mapValues(Type newElementType,
491                         function_ref<APInt(const APInt &)> mapping) const {
492   switch (getKind()) {
493   case StandardAttributes::DenseIntOrFPElements:
494     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
495   default:
496     llvm_unreachable("unsupported ElementsAttr subtype");
497   }
498 }
499 
500 ElementsAttr
501 ElementsAttr::mapValues(Type newElementType,
502                         function_ref<APInt(const APFloat &)> mapping) const {
503   switch (getKind()) {
504   case StandardAttributes::DenseIntOrFPElements:
505     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
506   default:
507     llvm_unreachable("unsupported ElementsAttr subtype");
508   }
509 }
510 
511 /// Returns the 1 dimensional flattened row-major index from the given
512 /// multi-dimensional index.
513 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
514   assert(isValidIndex(index) && "expected valid multi-dimensional index");
515   auto type = getType();
516 
517   // Reduce the provided multidimensional index into a flattended 1D row-major
518   // index.
519   auto rank = type.getRank();
520   auto shape = type.getShape();
521   uint64_t valueIndex = 0;
522   uint64_t dimMultiplier = 1;
523   for (int i = rank - 1; i >= 0; --i) {
524     valueIndex += index[i] * dimMultiplier;
525     dimMultiplier *= shape[i];
526   }
527   return valueIndex;
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // DenseElementAttr Utilities
532 //===----------------------------------------------------------------------===//
533 
534 /// Get the bitwidth of a dense element type within the buffer.
535 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
536 static size_t getDenseElementStorageWidth(size_t origWidth) {
537   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
538 }
539 static size_t getDenseElementStorageWidth(Type elementType) {
540   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
541 }
542 
543 /// Set a bit to a specific value.
544 static void setBit(char *rawData, size_t bitPos, bool value) {
545   if (value)
546     rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
547   else
548     rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
549 }
550 
551 /// Return the value of the specified bit.
552 static bool getBit(const char *rawData, size_t bitPos) {
553   return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
554 }
555 
556 /// Get start position of actual data in `value`. Actual data is
557 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian.
558 static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
559   char *dataPos =
560       const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
561   if (llvm::support::endian::system_endianness() ==
562       llvm::support::endianness::big)
563     dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
564   return dataPos;
565 }
566 
567 /// Read APInt `value` from appropriate position.
568 static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
569   char *dataPos = getAPIntDataPos(value, bitWidth);
570   std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
571 }
572 
573 /// Write `inData` to appropriate position of APInt `value`.
574 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
575   char *dataPos = getAPIntDataPos(value, bitWidth);
576   std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
577 }
578 
579 /// Writes value to the bit position `bitPos` in array `rawData`.
580 static void writeBits(char *rawData, size_t bitPos, APInt value) {
581   size_t bitWidth = value.getBitWidth();
582 
583   // If the bitwidth is 1 we just toggle the specific bit.
584   if (bitWidth == 1)
585     return setBit(rawData, bitPos, value.isOneValue());
586 
587   // Otherwise, the bit position is guaranteed to be byte aligned.
588   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
589   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
590 }
591 
592 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
593 /// `rawData`.
594 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
595   // Handle a boolean bit position.
596   if (bitWidth == 1)
597     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
598 
599   // Otherwise, the bit position must be 8-bit aligned.
600   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
601   APInt result(bitWidth, 0);
602   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
603   return result;
604 }
605 
606 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
607 /// same element count as 'type'.
608 template <typename Values>
609 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
610   return (values.size() == 1) ||
611          (type.getNumElements() == static_cast<int64_t>(values.size()));
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // DenseElementAttr Iterators
616 //===----------------------------------------------------------------------===//
617 
618 //===----------------------------------------------------------------------===//
619 // AttributeElementIterator
620 
621 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
622     DenseElementsAttr attr, size_t index)
623     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
624                                       Attribute, Attribute, Attribute>(
625           attr.getAsOpaquePointer(), index) {}
626 
627 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
628   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
629   Type eltTy = owner.getType().getElementType();
630   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
631     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
632   if (eltTy.isa<IndexType>())
633     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
634   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
635     IntElementIterator intIt(owner, index);
636     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
637     return FloatAttr::get(eltTy, *floatIt);
638   }
639   if (owner.isa<DenseStringElementsAttr>()) {
640     ArrayRef<StringRef> vals = owner.getRawStringData();
641     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
642   }
643   llvm_unreachable("unexpected element type");
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // BoolElementIterator
648 
649 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
650     DenseElementsAttr attr, size_t dataIndex)
651     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
652           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
653 
654 bool DenseElementsAttr::BoolElementIterator::operator*() const {
655   return getBit(getData(), getDataIndex());
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // IntElementIterator
660 
661 DenseElementsAttr::IntElementIterator::IntElementIterator(
662     DenseElementsAttr attr, size_t dataIndex)
663     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
664           attr.getRawData().data(), attr.isSplat(), dataIndex),
665       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
666 
667 APInt DenseElementsAttr::IntElementIterator::operator*() const {
668   return readBits(getData(),
669                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
670                   bitWidth);
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // ComplexIntElementIterator
675 
676 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
677     DenseElementsAttr attr, size_t dataIndex)
678     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
679                                       std::complex<APInt>, std::complex<APInt>,
680                                       std::complex<APInt>>(
681           attr.getRawData().data(), attr.isSplat(), dataIndex) {
682   auto complexType = attr.getType().getElementType().cast<ComplexType>();
683   bitWidth = getDenseElementBitWidth(complexType.getElementType());
684 }
685 
686 std::complex<APInt>
687 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
688   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
689   size_t offset = getDataIndex() * storageWidth * 2;
690   return {readBits(getData(), offset, bitWidth),
691           readBits(getData(), offset + storageWidth, bitWidth)};
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // FloatElementIterator
696 
697 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
698     const llvm::fltSemantics &smt, IntElementIterator it)
699     : llvm::mapped_iterator<IntElementIterator,
700                             std::function<APFloat(const APInt &)>>(
701           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
702 
703 //===----------------------------------------------------------------------===//
704 // ComplexFloatElementIterator
705 
706 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
707     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
708     : llvm::mapped_iterator<
709           ComplexIntElementIterator,
710           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
711           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
712             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
713           }) {}
714 
715 //===----------------------------------------------------------------------===//
716 // DenseElementsAttr
717 //===----------------------------------------------------------------------===//
718 
719 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
720                                          ArrayRef<Attribute> values) {
721   assert(hasSameElementsOrSplat(type, values));
722 
723   // If the element type is not based on int/float/index, assume it is a string
724   // type.
725   auto eltType = type.getElementType();
726   if (!type.getElementType().isIntOrIndexOrFloat()) {
727     SmallVector<StringRef, 8> stringValues;
728     stringValues.reserve(values.size());
729     for (Attribute attr : values) {
730       assert(attr.isa<StringAttr>() &&
731              "expected string value for non integer/index/float element");
732       stringValues.push_back(attr.cast<StringAttr>().getValue());
733     }
734     return get(type, stringValues);
735   }
736 
737   // Otherwise, get the raw storage width to use for the allocation.
738   size_t bitWidth = getDenseElementBitWidth(eltType);
739   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
740 
741   // Compress the attribute values into a character buffer.
742   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
743                             values.size());
744   APInt intVal;
745   for (unsigned i = 0, e = values.size(); i < e; ++i) {
746     assert(eltType == values[i].getType() &&
747            "expected attribute value to have element type");
748 
749     switch (eltType.getKind()) {
750     case StandardTypes::BF16:
751     case StandardTypes::F16:
752     case StandardTypes::F32:
753     case StandardTypes::F64:
754       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
755       break;
756     case StandardTypes::Integer:
757     case StandardTypes::Index:
758       intVal = values[i].cast<IntegerAttr>().getValue();
759       break;
760     default:
761       llvm_unreachable("unexpected element type");
762     }
763     assert(intVal.getBitWidth() == bitWidth &&
764            "expected value to have same bitwidth as element type");
765     writeBits(data.data(), i * storageBitWidth, intVal);
766   }
767   return DenseIntOrFPElementsAttr::getRaw(type, data,
768                                           /*isSplat=*/(values.size() == 1));
769 }
770 
771 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
772                                          ArrayRef<bool> values) {
773   assert(hasSameElementsOrSplat(type, values));
774   assert(type.getElementType().isInteger(1));
775 
776   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
777   for (int i = 0, e = values.size(); i != e; ++i)
778     setBit(buff.data(), i, values[i]);
779   return DenseIntOrFPElementsAttr::getRaw(type, buff,
780                                           /*isSplat=*/(values.size() == 1));
781 }
782 
783 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
784                                          ArrayRef<StringRef> values) {
785   assert(!type.getElementType().isIntOrFloat());
786   return DenseStringElementsAttr::get(type, values);
787 }
788 
789 /// Constructs a dense integer elements attribute from an array of APInt
790 /// values. Each APInt value is expected to have the same bitwidth as the
791 /// element type of 'type'.
792 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
793                                          ArrayRef<APInt> values) {
794   assert(type.getElementType().isIntOrIndex());
795   assert(hasSameElementsOrSplat(type, values));
796   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
797   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
798                                           /*isSplat=*/(values.size() == 1));
799 }
800 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
801                                          ArrayRef<std::complex<APInt>> values) {
802   ComplexType complex = type.getElementType().cast<ComplexType>();
803   assert(complex.getElementType().isa<IntegerType>());
804   assert(hasSameElementsOrSplat(type, values));
805   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
806   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
807                           values.size() * 2);
808   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
809                                           /*isSplat=*/(values.size() == 1));
810 }
811 
812 // Constructs a dense float elements attribute from an array of APFloat
813 // values. Each APFloat value is expected to have the same bitwidth as the
814 // element type of 'type'.
815 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
816                                          ArrayRef<APFloat> values) {
817   assert(type.getElementType().isa<FloatType>());
818   assert(hasSameElementsOrSplat(type, values));
819   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
820   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
821                                           /*isSplat=*/(values.size() == 1));
822 }
823 DenseElementsAttr
824 DenseElementsAttr::get(ShapedType type,
825                        ArrayRef<std::complex<APFloat>> values) {
826   ComplexType complex = type.getElementType().cast<ComplexType>();
827   assert(complex.getElementType().isa<FloatType>());
828   assert(hasSameElementsOrSplat(type, values));
829   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
830                            values.size() * 2);
831   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
832   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
833                                           /*isSplat=*/(values.size() == 1));
834 }
835 
836 /// Construct a dense elements attribute from a raw buffer representing the
837 /// data for this attribute. Users should generally not use this methods as
838 /// the expected buffer format may not be a form the user expects.
839 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
840                                                       ArrayRef<char> rawBuffer,
841                                                       bool isSplatBuffer) {
842   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
843 }
844 
845 /// Returns true if the given buffer is a valid raw buffer for the given type.
846 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
847                                          ArrayRef<char> rawBuffer,
848                                          bool &detectedSplat) {
849   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
850   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
851 
852   // Storage width of 1 is special as it is packed by the bit.
853   if (storageWidth == 1) {
854     // Check for a splat, or a buffer equal to the number of elements.
855     if ((detectedSplat = rawBuffer.size() == 1))
856       return true;
857     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
858   }
859   // All other types are 8-bit aligned.
860   if ((detectedSplat = rawBufferWidth == storageWidth))
861     return true;
862   return rawBufferWidth == (storageWidth * type.getNumElements());
863 }
864 
865 /// Check the information for a C++ data type, check if this type is valid for
866 /// the current attribute. This method is used to verify specific type
867 /// invariants that the templatized 'getValues' method cannot.
868 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
869                               bool isSigned) {
870   // Make sure that the data element size is the same as the type element width.
871   if (getDenseElementBitWidth(type) !=
872       static_cast<size_t>(dataEltSize * CHAR_BIT))
873     return false;
874 
875   // Check that the element type is either float or integer or index.
876   if (!isInt)
877     return type.isa<FloatType>();
878   if (type.isIndex())
879     return true;
880 
881   auto intType = type.dyn_cast<IntegerType>();
882   if (!intType)
883     return false;
884 
885   // Make sure signedness semantics is consistent.
886   if (intType.isSignless())
887     return true;
888   return intType.isSigned() ? isSigned : !isSigned;
889 }
890 
891 /// Defaults down the subclass implementation.
892 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
893                                                    ArrayRef<char> data,
894                                                    int64_t dataEltSize,
895                                                    bool isInt, bool isSigned) {
896   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
897                                                  isSigned);
898 }
899 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
900                                                       ArrayRef<char> data,
901                                                       int64_t dataEltSize,
902                                                       bool isInt,
903                                                       bool isSigned) {
904   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
905                                                     isInt, isSigned);
906 }
907 
908 /// A method used to verify specific type invariants that the templatized 'get'
909 /// method cannot.
910 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
911                                           bool isSigned) const {
912   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
913                              isSigned);
914 }
915 
916 /// Check the information for a C++ data type, check if this type is valid for
917 /// the current attribute.
918 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
919                                        bool isSigned) const {
920   return ::isValidIntOrFloat(
921       getType().getElementType().cast<ComplexType>().getElementType(),
922       dataEltSize / 2, isInt, isSigned);
923 }
924 
925 /// Returns if this attribute corresponds to a splat, i.e. if all element
926 /// values are the same.
927 bool DenseElementsAttr::isSplat() const {
928   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
929 }
930 
931 /// Return the held element values as a range of Attributes.
932 auto DenseElementsAttr::getAttributeValues() const
933     -> llvm::iterator_range<AttributeElementIterator> {
934   return {attr_value_begin(), attr_value_end()};
935 }
936 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
937   return AttributeElementIterator(*this, 0);
938 }
939 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
940   return AttributeElementIterator(*this, getNumElements());
941 }
942 
943 /// Return the held element values as a range of bool. The element type of
944 /// this attribute must be of integer type of bitwidth 1.
945 auto DenseElementsAttr::getBoolValues() const
946     -> llvm::iterator_range<BoolElementIterator> {
947   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
948   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
949   (void)eltType;
950   return {BoolElementIterator(*this, 0),
951           BoolElementIterator(*this, getNumElements())};
952 }
953 
954 /// Return the held element values as a range of APInts. The element type of
955 /// this attribute must be of integer type.
956 auto DenseElementsAttr::getIntValues() const
957     -> llvm::iterator_range<IntElementIterator> {
958   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
959   return {raw_int_begin(), raw_int_end()};
960 }
961 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
962   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
963   return raw_int_begin();
964 }
965 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
966   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
967   return raw_int_end();
968 }
969 auto DenseElementsAttr::getComplexIntValues() const
970     -> llvm::iterator_range<ComplexIntElementIterator> {
971   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
972   (void)eltTy;
973   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
974   return {ComplexIntElementIterator(*this, 0),
975           ComplexIntElementIterator(*this, getNumElements())};
976 }
977 
978 /// Return the held element values as a range of APFloat. The element type of
979 /// this attribute must be of float type.
980 auto DenseElementsAttr::getFloatValues() const
981     -> llvm::iterator_range<FloatElementIterator> {
982   auto elementType = getType().getElementType().cast<FloatType>();
983   const auto &elementSemantics = elementType.getFloatSemantics();
984   return {FloatElementIterator(elementSemantics, raw_int_begin()),
985           FloatElementIterator(elementSemantics, raw_int_end())};
986 }
987 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
988   return getFloatValues().begin();
989 }
990 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
991   return getFloatValues().end();
992 }
993 auto DenseElementsAttr::getComplexFloatValues() const
994     -> llvm::iterator_range<ComplexFloatElementIterator> {
995   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
996   assert(eltTy.isa<FloatType>() && "expected complex float type");
997   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
998   return {{semantics, {*this, 0}},
999           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1000 }
1001 
1002 /// Return the raw storage data held by this attribute.
1003 ArrayRef<char> DenseElementsAttr::getRawData() const {
1004   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1005 }
1006 
1007 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1008   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1009 }
1010 
1011 /// Return a new DenseElementsAttr that has the same data as the current
1012 /// attribute, but has been reshaped to 'newType'. The new type must have the
1013 /// same total number of elements as well as element type.
1014 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1015   ShapedType curType = getType();
1016   if (curType == newType)
1017     return *this;
1018 
1019   (void)curType;
1020   assert(newType.getElementType() == curType.getElementType() &&
1021          "expected the same element type");
1022   assert(newType.getNumElements() == curType.getNumElements() &&
1023          "expected the same number of elements");
1024   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1025 }
1026 
1027 DenseElementsAttr
1028 DenseElementsAttr::mapValues(Type newElementType,
1029                              function_ref<APInt(const APInt &)> mapping) const {
1030   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1031 }
1032 
1033 DenseElementsAttr DenseElementsAttr::mapValues(
1034     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1035   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // DenseStringElementsAttr
1040 //===----------------------------------------------------------------------===//
1041 
1042 DenseStringElementsAttr
1043 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1044   return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
1045                    type, values, (values.size() == 1));
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // DenseIntOrFPElementsAttr
1050 //===----------------------------------------------------------------------===//
1051 
1052 /// Utility method to write a range of APInt values to a buffer.
1053 template <typename APRangeT>
1054 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1055                                 APRangeT &&values) {
1056   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1057   size_t offset = 0;
1058   for (auto it = values.begin(), e = values.end(); it != e;
1059        ++it, offset += storageWidth) {
1060     assert((*it).getBitWidth() <= storageWidth);
1061     writeBits(data.data(), offset, *it);
1062   }
1063 }
1064 
1065 /// Constructs a dense elements attribute from an array of raw APFloat values.
1066 /// Each APFloat value is expected to have the same bitwidth as the element
1067 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1068 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1069                                                    size_t storageWidth,
1070                                                    ArrayRef<APFloat> values,
1071                                                    bool isSplat) {
1072   std::vector<char> data;
1073   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1074   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1075   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1076 }
1077 
1078 /// Constructs a dense elements attribute from an array of raw APInt values.
1079 /// Each APInt value is expected to have the same bitwidth as the element type
1080 /// of 'type'.
1081 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1082                                                    size_t storageWidth,
1083                                                    ArrayRef<APInt> values,
1084                                                    bool isSplat) {
1085   std::vector<char> data;
1086   writeAPIntsToBuffer(storageWidth, data, values);
1087   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1088 }
1089 
1090 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1091                                                    ArrayRef<char> data,
1092                                                    bool isSplat) {
1093   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1094          "type must be ranked tensor or vector");
1095   assert(type.hasStaticShape() && "type must have static shape");
1096   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
1097                    type, data, isSplat);
1098 }
1099 
1100 /// Overload of the raw 'get' method that asserts that the given type is of
1101 /// complex type. This method is used to verify type invariants that the
1102 /// templatized 'get' method cannot.
1103 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1104                                                           ArrayRef<char> data,
1105                                                           int64_t dataEltSize,
1106                                                           bool isInt,
1107                                                           bool isSigned) {
1108   assert(::isValidIntOrFloat(
1109       type.getElementType().cast<ComplexType>().getElementType(),
1110       dataEltSize / 2, isInt, isSigned));
1111 
1112   int64_t numElements = data.size() / dataEltSize;
1113   assert(numElements == 1 || numElements == type.getNumElements());
1114   return getRaw(type, data, /*isSplat=*/numElements == 1);
1115 }
1116 
1117 /// Overload of the 'getRaw' method that asserts that the given type is of
1118 /// integer type. This method is used to verify type invariants that the
1119 /// templatized 'get' method cannot.
1120 DenseElementsAttr
1121 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1122                                            int64_t dataEltSize, bool isInt,
1123                                            bool isSigned) {
1124   assert(
1125       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1126 
1127   int64_t numElements = data.size() / dataEltSize;
1128   assert(numElements == 1 || numElements == type.getNumElements());
1129   return getRaw(type, data, /*isSplat=*/numElements == 1);
1130 }
1131 
1132 //===----------------------------------------------------------------------===//
1133 // DenseFPElementsAttr
1134 //===----------------------------------------------------------------------===//
1135 
1136 template <typename Fn, typename Attr>
1137 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1138                                 Type newElementType,
1139                                 llvm::SmallVectorImpl<char> &data) {
1140   size_t bitWidth = getDenseElementBitWidth(newElementType);
1141   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1142 
1143   ShapedType newArrayType;
1144   if (inType.isa<RankedTensorType>())
1145     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1146   else if (inType.isa<UnrankedTensorType>())
1147     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1148   else if (inType.isa<VectorType>())
1149     newArrayType = VectorType::get(inType.getShape(), newElementType);
1150   else
1151     assert(newArrayType && "Unhandled tensor type");
1152 
1153   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1154   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1155 
1156   // Functor used to process a single element value of the attribute.
1157   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1158     auto newInt = mapping(value);
1159     assert(newInt.getBitWidth() == bitWidth);
1160     writeBits(data.data(), index * storageBitWidth, newInt);
1161   };
1162 
1163   // Check for the splat case.
1164   if (attr.isSplat()) {
1165     processElt(*attr.begin(), /*index=*/0);
1166     return newArrayType;
1167   }
1168 
1169   // Otherwise, process all of the element values.
1170   uint64_t elementIdx = 0;
1171   for (auto value : attr)
1172     processElt(value, elementIdx++);
1173   return newArrayType;
1174 }
1175 
1176 DenseElementsAttr DenseFPElementsAttr::mapValues(
1177     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1178   llvm::SmallVector<char, 8> elementData;
1179   auto newArrayType =
1180       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1181 
1182   return getRaw(newArrayType, elementData, isSplat());
1183 }
1184 
1185 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1186 bool DenseFPElementsAttr::classof(Attribute attr) {
1187   return attr.isa<DenseElementsAttr>() &&
1188          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // DenseIntElementsAttr
1193 //===----------------------------------------------------------------------===//
1194 
1195 DenseElementsAttr DenseIntElementsAttr::mapValues(
1196     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1197   llvm::SmallVector<char, 8> elementData;
1198   auto newArrayType =
1199       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1200 
1201   return getRaw(newArrayType, elementData, isSplat());
1202 }
1203 
1204 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1205 bool DenseIntElementsAttr::classof(Attribute attr) {
1206   return attr.isa<DenseElementsAttr>() &&
1207          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // OpaqueElementsAttr
1212 //===----------------------------------------------------------------------===//
1213 
1214 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1215                                            StringRef bytes) {
1216   assert(TensorType::isValidElementType(type.getElementType()) &&
1217          "Input element type should be a valid tensor element type");
1218   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
1219                    dialect, bytes);
1220 }
1221 
1222 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1223 
1224 /// Return the value at the given index. If index does not refer to a valid
1225 /// element, then a null attribute is returned.
1226 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1227   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1228   if (Dialect *dialect = getDialect())
1229     return dialect->extractElementHook(*this, index);
1230   return Attribute();
1231 }
1232 
1233 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1234 
1235 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1236   if (auto *d = getDialect())
1237     return d->decodeHook(*this, result);
1238   return true;
1239 }
1240 
1241 //===----------------------------------------------------------------------===//
1242 // SparseElementsAttr
1243 //===----------------------------------------------------------------------===//
1244 
1245 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1246                                            DenseElementsAttr indices,
1247                                            DenseElementsAttr values) {
1248   assert(indices.getType().getElementType().isInteger(64) &&
1249          "expected sparse indices to be 64-bit integer values");
1250   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1251          "type must be ranked tensor or vector");
1252   assert(type.hasStaticShape() && "type must have static shape");
1253   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
1254                    indices.cast<DenseIntElementsAttr>(), values);
1255 }
1256 
1257 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1258   return getImpl()->indices;
1259 }
1260 
1261 DenseElementsAttr SparseElementsAttr::getValues() const {
1262   return getImpl()->values;
1263 }
1264 
1265 /// Return the value of the element at the given index.
1266 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1267   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1268   auto type = getType();
1269 
1270   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1271   // as a 1-D index array.
1272   auto sparseIndices = getIndices();
1273   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1274 
1275   // Check to see if the indices are a splat.
1276   if (sparseIndices.isSplat()) {
1277     // If the index is also not a splat of the index value, we know that the
1278     // value is zero.
1279     auto splatIndex = *sparseIndexValues.begin();
1280     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1281       return getZeroAttr();
1282 
1283     // If the indices are a splat, we also expect the values to be a splat.
1284     assert(getValues().isSplat() && "expected splat values");
1285     return getValues().getSplatValue();
1286   }
1287 
1288   // Build a mapping between known indices and the offset of the stored element.
1289   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1290   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1291   size_t rank = type.getRank();
1292   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1293     mappedIndices.try_emplace(
1294         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1295 
1296   // Look for the provided index key within the mapped indices. If the provided
1297   // index is not found, then return a zero attribute.
1298   auto it = mappedIndices.find(index);
1299   if (it == mappedIndices.end())
1300     return getZeroAttr();
1301 
1302   // Otherwise, return the held sparse value element.
1303   return getValues().getValue(it->second);
1304 }
1305 
1306 /// Get a zero APFloat for the given sparse attribute.
1307 APFloat SparseElementsAttr::getZeroAPFloat() const {
1308   auto eltType = getType().getElementType().cast<FloatType>();
1309   return APFloat(eltType.getFloatSemantics());
1310 }
1311 
1312 /// Get a zero APInt for the given sparse attribute.
1313 APInt SparseElementsAttr::getZeroAPInt() const {
1314   auto eltType = getType().getElementType().cast<IntegerType>();
1315   return APInt::getNullValue(eltType.getWidth());
1316 }
1317 
1318 /// Get a zero attribute for the given attribute type.
1319 Attribute SparseElementsAttr::getZeroAttr() const {
1320   auto eltType = getType().getElementType();
1321 
1322   // Handle floating point elements.
1323   if (eltType.isa<FloatType>())
1324     return FloatAttr::get(eltType, 0);
1325 
1326   // Otherwise, this is an integer.
1327   // TODO: Handle StringAttr here.
1328   return IntegerAttr::get(eltType, 0);
1329 }
1330 
1331 /// Flatten, and return, all of the sparse indices in this attribute in
1332 /// row-major order.
1333 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1334   std::vector<ptrdiff_t> flatSparseIndices;
1335 
1336   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1337   // as a 1-D index array.
1338   auto sparseIndices = getIndices();
1339   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1340   if (sparseIndices.isSplat()) {
1341     SmallVector<uint64_t, 8> indices(getType().getRank(),
1342                                      *sparseIndexValues.begin());
1343     flatSparseIndices.push_back(getFlattenedIndex(indices));
1344     return flatSparseIndices;
1345   }
1346 
1347   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1348   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1349   size_t rank = getType().getRank();
1350   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1351     flatSparseIndices.push_back(getFlattenedIndex(
1352         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1353   return flatSparseIndices;
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // MutableDictionaryAttr
1358 //===----------------------------------------------------------------------===//
1359 
1360 MutableDictionaryAttr::MutableDictionaryAttr(
1361     ArrayRef<NamedAttribute> attributes) {
1362   setAttrs(attributes);
1363 }
1364 
1365 /// Return the underlying dictionary attribute.
1366 DictionaryAttr
1367 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1368   // Construct empty DictionaryAttr if needed.
1369   if (!attrs)
1370     return DictionaryAttr::get({}, context);
1371   return attrs;
1372 }
1373 
1374 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1375   return attrs ? attrs.getValue() : llvm::None;
1376 }
1377 
1378 /// Replace the held attributes with ones provided in 'newAttrs'.
1379 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1380   // Don't create an attribute list if there are no attributes.
1381   if (attributes.empty())
1382     attrs = nullptr;
1383   else
1384     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1385 }
1386 
1387 /// Return the specified attribute if present, null otherwise.
1388 Attribute MutableDictionaryAttr::get(StringRef name) const {
1389   return attrs ? attrs.get(name) : nullptr;
1390 }
1391 
1392 /// Return the specified attribute if present, null otherwise.
1393 Attribute MutableDictionaryAttr::get(Identifier name) const {
1394   return attrs ? attrs.get(name) : nullptr;
1395 }
1396 
1397 /// Return the specified named attribute if present, None otherwise.
1398 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1399   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1400 }
1401 Optional<NamedAttribute>
1402 MutableDictionaryAttr::getNamed(Identifier name) const {
1403   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1404 }
1405 
1406 /// If the an attribute exists with the specified name, change it to the new
1407 /// value.  Otherwise, add a new attribute with the specified name/value.
1408 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1409   assert(value && "attributes may never be null");
1410 
1411   // Look for an existing value for the given name, and set it in-place.
1412   ArrayRef<NamedAttribute> values = getAttrs();
1413   const auto *it = llvm::find_if(
1414       values, [name](NamedAttribute attr) { return attr.first == name; });
1415   if (it != values.end()) {
1416     // Bail out early if the value is the same as what we already have.
1417     if (it->second == value)
1418       return;
1419 
1420     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1421     newAttrs[it - values.begin()].second = value;
1422     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1423     return;
1424   }
1425 
1426   // Otherwise, insert the new attribute into its sorted position.
1427   it = llvm::lower_bound(values, name);
1428   SmallVector<NamedAttribute, 8> newAttrs;
1429   newAttrs.reserve(values.size() + 1);
1430   newAttrs.append(values.begin(), it);
1431   newAttrs.push_back({name, value});
1432   newAttrs.append(it, values.end());
1433   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1434 }
1435 
1436 /// Remove the attribute with the specified name if it exists.  The return
1437 /// value indicates whether the attribute was present or not.
1438 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1439   auto origAttrs = getAttrs();
1440   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1441     if (origAttrs[i].first == name) {
1442       // Handle the simple case of removing the only attribute in the list.
1443       if (e == 1) {
1444         attrs = nullptr;
1445         return RemoveResult::Removed;
1446       }
1447 
1448       SmallVector<NamedAttribute, 8> newAttrs;
1449       newAttrs.reserve(origAttrs.size() - 1);
1450       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1451       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1452       attrs = DictionaryAttr::getWithSorted(newAttrs,
1453                                             newAttrs[0].second.getContext());
1454       return RemoveResult::Removed;
1455     }
1456   }
1457   return RemoveResult::NotFound;
1458 }
1459 
1460 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
1461   return strcmp(lhs.first.data(), rhs.first.data()) < 0;
1462 }
1463 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
1464   // This is correct even when attr.first.data()[name.size()] is not a zero
1465   // string terminator, because we only care about a less than comparison.
1466   // This can't use memcmp, because it doesn't guarantee that it will stop
1467   // reading both buffers if one is shorter than the other, even if there is
1468   // a difference.
1469   return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
1470 }
1471