//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AttributeStorage
26 //===----------------------------------------------------------------------===//
27 
28 AttributeStorage::AttributeStorage(Type type)
29     : type(type.getAsOpaquePointer()) {}
30 AttributeStorage::AttributeStorage() : type(nullptr) {}
31 
32 Type AttributeStorage::getType() const {
33   return Type::getFromOpaquePointer(type);
34 }
35 void AttributeStorage::setType(Type newType) {
36   type = newType.getAsOpaquePointer();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Attribute
41 //===----------------------------------------------------------------------===//
42 
43 /// Return the type of this attribute.
44 Type Attribute::getType() const { return impl->getType(); }
45 
46 /// Return the context this attribute belongs to.
47 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
48 
49 /// Get the dialect this attribute is registered to.
50 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
51 
52 //===----------------------------------------------------------------------===//
53 // AffineMapAttr
54 //===----------------------------------------------------------------------===//
55 
56 AffineMapAttr AffineMapAttr::get(AffineMap value) {
57   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
58 }
59 
60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
61 
62 //===----------------------------------------------------------------------===//
63 // ArrayAttr
64 //===----------------------------------------------------------------------===//
65 
66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
67   return Base::get(context, StandardAttributes::Array, value);
68 }
69 
70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
71 
72 Attribute ArrayAttr::operator[](unsigned idx) const {
73   assert(idx < size() && "index out of bounds");
74   return getValue()[idx];
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // BoolAttr
79 //===----------------------------------------------------------------------===//
80 
81 bool BoolAttr::getValue() const { return getImpl()->value; }
82 
83 //===----------------------------------------------------------------------===//
84 // DictionaryAttr
85 //===----------------------------------------------------------------------===//
86 
87 /// Perform a three-way comparison between the names of the specified
88 /// NamedAttributes.
89 static int compareNamedAttributes(const NamedAttribute *lhs,
90                                   const NamedAttribute *rhs) {
91   return strcmp(lhs->first.data(), rhs->first.data());
92 }
93 
/// Returns true if the name of 'attr' orders strictly before 'name'. The
/// ordering agrees with compareNamedAttributes (byte-wise, strcmp-style), which
/// makes this usable as the comparator for llvm::lower_bound over a dictionary
/// sorted by that function (see DictionaryAttr::getNamed(StringRef)).
// NOTE(review): 'attr.first.data()' is assumed to be NUL-terminated, so
// strncmp stops at the end of the attribute name when it is shorter than
// 'name'; 'name' itself need not be terminated since at most name.size()
// bytes of it are read.
static bool compareNamedAttributeWithName(const NamedAttribute &attr,
                                          StringRef name) {
  // This is correct even when attr.first.data()[name.size()] is not a zero
  // string terminator, because we only care about a less than comparison.
  // This can't use memcmp, because it doesn't guarantee that it will stop
  // reading both buffers if one is shorter than the other, even if there is
  // a difference.
  return strncmp(attr.first.data(), name.data(), name.size()) < 0;
}
104 
/// Sort NamedAttributes by name, either in place or into a separate buffer.
///
/// If 'inPlace' is true, 'value' and 'storage' must refer to the same
/// elements, and those elements are reordered directly. Otherwise 'value' is
/// the source and 'storage' receives a sorted copy, but only when a sort was
/// actually needed.
///
/// Returns true if the elements were out of order and were sorted. In that
/// case, for !inPlace, 'storage' holds the sorted result. Returns false if
/// 'value' was already sorted; for !inPlace, 'storage' is then left untouched
/// and callers should keep using 'value'.
///
/// Duplicate names are diagnosed only in builds with assertions enabled.
template <bool inPlace>
static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
                               SmallVectorImpl<NamedAttribute> &storage) {
  // Specialize for the common case.
  switch (value.size()) {
  case 0:
  case 1:
    // Zero or one elements are already sorted.
    break;
  case 2:
    // Two elements: a single comparison decides whether to swap.
    assert(value[0].first != value[1].first &&
           "DictionaryAttr element names must be unique");
    if (compareNamedAttributes(&value[0], &value[1]) > 0) {
      if (inPlace)
        std::swap(storage[0], storage[1]);
      else
        storage.append({value[1], value[0]});
      return true;
    }
    break;
  default:
    // Check to see they are sorted already.
    bool isSorted =
        llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) {
          return compareNamedAttributes(&l, &r) < 0;
        });
    if (!isSorted) {
      // If not, do a general sort.
      if (!inPlace)
        storage.append(value.begin(), value.end());
      llvm::array_pod_sort(storage.begin(), storage.end(),
                           compareNamedAttributes);
      // Rebind 'value' so the uniqueness check below inspects the sorted
      // elements (adjacent duplicates are only detectable after sorting).
      value = storage;
    }

    // Ensure that the attribute elements are unique.
    assert(std::adjacent_find(value.begin(), value.end(),
                              [](NamedAttribute l, NamedAttribute r) {
                                return l.first == r.first;
                              }) == value.end() &&
           "DictionaryAttr element names must be unique");
    return !isSorted;
  }
  return false;
}
154 
155 /// Sorts the NamedAttributes in the array ordered by name as expected by
156 /// getWithSorted.
157 /// Requires: uniquely named attributes.
158 void DictionaryAttr::sort(SmallVectorImpl<NamedAttribute> &array) {
159   dictionaryAttrSort</*inPlace=*/true>(array, array);
160 }
161 
162 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
163                                    MLIRContext *context) {
164   assert(llvm::all_of(value,
165                       [](const NamedAttribute &attr) { return attr.second; }) &&
166          "value cannot have null entries");
167 
168   // We need to sort the element list to canonicalize it.
169   SmallVector<NamedAttribute, 8> storage;
170   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
171     value = storage;
172 
173   return Base::get(context, StandardAttributes::Dictionary, value);
174 }
175 
176 /// Construct a dictionary with an array of values that is known to already be
177 /// sorted by name and uniqued.
178 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
179                                              MLIRContext *context) {
180   // Ensure that the attribute elements are unique and sorted.
181   assert(llvm::is_sorted(value,
182                          [](NamedAttribute l, NamedAttribute r) {
183                            return l.first.strref() < r.first.strref();
184                          }) &&
185          "expected attribute values to be sorted");
186   assert(std::adjacent_find(value.begin(), value.end(),
187                             [](NamedAttribute l, NamedAttribute r) {
188                               return l.first == r.first;
189                             }) == value.end() &&
190          "DictionaryAttr element names must be unique");
191   return Base::get(context, StandardAttributes::Dictionary, value);
192 }
193 
194 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
195   return getImpl()->getElements();
196 }
197 
198 /// Return the specified attribute if present, null otherwise.
199 Attribute DictionaryAttr::get(StringRef name) const {
200   Optional<NamedAttribute> attr = getNamed(name);
201   return attr ? attr->second : nullptr;
202 }
203 Attribute DictionaryAttr::get(Identifier name) const {
204   Optional<NamedAttribute> attr = getNamed(name);
205   return attr ? attr->second : nullptr;
206 }
207 
208 /// Return the specified named attribute if present, None otherwise.
209 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
210   ArrayRef<NamedAttribute> values = getValue();
211   auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
212   return it != values.end() && it->first == name ? *it
213                                                  : Optional<NamedAttribute>();
214 }
215 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
216   for (auto elt : getValue())
217     if (elt.first == name)
218       return elt;
219   return llvm::None;
220 }
221 
222 DictionaryAttr::iterator DictionaryAttr::begin() const {
223   return getValue().begin();
224 }
225 DictionaryAttr::iterator DictionaryAttr::end() const {
226   return getValue().end();
227 }
228 size_t DictionaryAttr::size() const { return getValue().size(); }
229 
230 //===----------------------------------------------------------------------===//
231 // FloatAttr
232 //===----------------------------------------------------------------------===//
233 
234 FloatAttr FloatAttr::get(Type type, double value) {
235   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
236 }
237 
238 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
239   return Base::getChecked(loc, StandardAttributes::Float, type, value);
240 }
241 
242 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
243   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
244 }
245 
246 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
247   return Base::getChecked(loc, StandardAttributes::Float, type, value);
248 }
249 
250 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
251 
252 double FloatAttr::getValueAsDouble() const {
253   return getValueAsDouble(getValue());
254 }
255 double FloatAttr::getValueAsDouble(APFloat value) {
256   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
257     bool losesInfo = false;
258     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
259                   &losesInfo);
260   }
261   return value.convertToDouble();
262 }
263 
264 /// Verify construction invariants.
265 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
266   if (!type.isa<FloatType>())
267     return emitError(loc, "expected floating point type");
268   return success();
269 }
270 
271 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
272                                                       double value) {
273   return verifyFloatTypeInvariants(loc, type);
274 }
275 
276 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
277                                                       const APFloat &value) {
278   // Verify that the type is correct.
279   if (failed(verifyFloatTypeInvariants(loc, type)))
280     return failure();
281 
282   // Verify that the type semantics match that of the value.
283   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
284     return emitError(
285         loc, "FloatAttr type doesn't match the type implied by its value");
286   }
287   return success();
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // SymbolRefAttr
292 //===----------------------------------------------------------------------===//
293 
294 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
295   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
296       .cast<FlatSymbolRefAttr>();
297 }
298 
299 SymbolRefAttr SymbolRefAttr::get(StringRef value,
300                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
301                                  MLIRContext *ctx) {
302   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
303 }
304 
305 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
306 
307 StringRef SymbolRefAttr::getLeafReference() const {
308   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
309   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
310 }
311 
312 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
313   return getImpl()->getNestedRefs();
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // IntegerAttr
318 //===----------------------------------------------------------------------===//
319 
320 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
321   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
322 }
323 
324 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
325   // This uses 64 bit APInts by default for index type.
326   if (type.isIndex())
327     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
328 
329   auto intType = type.cast<IntegerType>();
330   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
331 }
332 
333 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
334 
335 int64_t IntegerAttr::getInt() const {
336   assert((getImpl()->getType().isIndex() ||
337           getImpl()->getType().isSignlessInteger()) &&
338          "must be signless integer");
339   return getValue().getSExtValue();
340 }
341 
342 int64_t IntegerAttr::getSInt() const {
343   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
344   return getValue().getSExtValue();
345 }
346 
347 uint64_t IntegerAttr::getUInt() const {
348   assert(getImpl()->getType().isUnsignedInteger() &&
349          "must be unsigned integer");
350   return getValue().getZExtValue();
351 }
352 
353 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
354   if (type.isa<IntegerType>() || type.isa<IndexType>())
355     return success();
356   return emitError(loc, "expected integer or index type");
357 }
358 
359 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
360                                                         int64_t value) {
361   return verifyIntegerTypeInvariants(loc, type);
362 }
363 
364 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
365                                                         const APInt &value) {
366   if (failed(verifyIntegerTypeInvariants(loc, type)))
367     return failure();
368   if (auto integerType = type.dyn_cast<IntegerType>())
369     if (integerType.getWidth() != value.getBitWidth())
370       return emitError(loc, "integer type bit width (")
371              << integerType.getWidth() << ") doesn't match value bit width ("
372              << value.getBitWidth() << ")";
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // IntegerSetAttr
378 //===----------------------------------------------------------------------===//
379 
380 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
381   return Base::get(value.getConstraint(0).getContext(),
382                    StandardAttributes::IntegerSet, value);
383 }
384 
385 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
386 
387 //===----------------------------------------------------------------------===//
388 // OpaqueAttr
389 //===----------------------------------------------------------------------===//
390 
391 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
392                            MLIRContext *context) {
393   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
394                    type);
395 }
396 
397 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
398                                   Type type, Location location) {
399   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
400                           attrData, type);
401 }
402 
403 /// Returns the dialect namespace of the opaque attribute.
404 Identifier OpaqueAttr::getDialectNamespace() const {
405   return getImpl()->dialectNamespace;
406 }
407 
408 /// Returns the raw attribute data of the opaque attribute.
409 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
410 
411 /// Verify the construction of an opaque attribute.
412 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
413                                                        Identifier dialect,
414                                                        StringRef attrData,
415                                                        Type type) {
416   if (!Dialect::isValidNamespace(dialect.strref()))
417     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
418   return success();
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // StringAttr
423 //===----------------------------------------------------------------------===//
424 
425 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
426   return get(bytes, NoneType::get(context));
427 }
428 
429 /// Get an instance of a StringAttr with the given string and Type.
430 StringAttr StringAttr::get(StringRef bytes, Type type) {
431   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
432 }
433 
434 StringRef StringAttr::getValue() const { return getImpl()->value; }
435 
436 //===----------------------------------------------------------------------===//
437 // TypeAttr
438 //===----------------------------------------------------------------------===//
439 
440 TypeAttr TypeAttr::get(Type value) {
441   return Base::get(value.getContext(), StandardAttributes::Type, value);
442 }
443 
444 Type TypeAttr::getValue() const { return getImpl()->value; }
445 
446 //===----------------------------------------------------------------------===//
447 // ElementsAttr
448 //===----------------------------------------------------------------------===//
449 
450 ShapedType ElementsAttr::getType() const {
451   return Attribute::getType().cast<ShapedType>();
452 }
453 
454 /// Returns the number of elements held by this attribute.
455 int64_t ElementsAttr::getNumElements() const {
456   return getType().getNumElements();
457 }
458 
459 /// Return the value at the given index. If index does not refer to a valid
460 /// element, then a null attribute is returned.
461 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
462   switch (getKind()) {
463   case StandardAttributes::DenseIntOrFPElements:
464     return cast<DenseElementsAttr>().getValue(index);
465   case StandardAttributes::OpaqueElements:
466     return cast<OpaqueElementsAttr>().getValue(index);
467   case StandardAttributes::SparseElements:
468     return cast<SparseElementsAttr>().getValue(index);
469   default:
470     llvm_unreachable("unknown ElementsAttr kind");
471   }
472 }
473 
474 /// Return if the given 'index' refers to a valid element in this attribute.
475 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
476   auto type = getType();
477 
478   // Verify that the rank of the indices matches the held type.
479   auto rank = type.getRank();
480   if (rank != static_cast<int64_t>(index.size()))
481     return false;
482 
483   // Verify that all of the indices are within the shape dimensions.
484   auto shape = type.getShape();
485   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
486     return static_cast<int64_t>(index[i]) < shape[i];
487   });
488 }
489 
490 ElementsAttr
491 ElementsAttr::mapValues(Type newElementType,
492                         function_ref<APInt(const APInt &)> mapping) const {
493   switch (getKind()) {
494   case StandardAttributes::DenseIntOrFPElements:
495     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
496   default:
497     llvm_unreachable("unsupported ElementsAttr subtype");
498   }
499 }
500 
501 ElementsAttr
502 ElementsAttr::mapValues(Type newElementType,
503                         function_ref<APInt(const APFloat &)> mapping) const {
504   switch (getKind()) {
505   case StandardAttributes::DenseIntOrFPElements:
506     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
507   default:
508     llvm_unreachable("unsupported ElementsAttr subtype");
509   }
510 }
511 
512 /// Returns the 1 dimensional flattened row-major index from the given
513 /// multi-dimensional index.
514 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
515   assert(isValidIndex(index) && "expected valid multi-dimensional index");
516   auto type = getType();
517 
518   // Reduce the provided multidimensional index into a flattended 1D row-major
519   // index.
520   auto rank = type.getRank();
521   auto shape = type.getShape();
522   uint64_t valueIndex = 0;
523   uint64_t dimMultiplier = 1;
524   for (int i = rank - 1; i >= 0; --i) {
525     valueIndex += index[i] * dimMultiplier;
526     dimMultiplier *= shape[i];
527   }
528   return valueIndex;
529 }
530 
531 //===----------------------------------------------------------------------===//
// DenseElementsAttr Utilities
533 //===----------------------------------------------------------------------===//
534 
535 /// Get the bitwidth of a dense element type within the buffer.
536 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
537 static size_t getDenseElementStorageWidth(size_t origWidth) {
538   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
539 }
540 static size_t getDenseElementStorageWidth(Type elementType) {
541   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
542 }
543 
/// Set or clear bit 'bitPos' in 'rawData'. Bits are numbered from the least
/// significant bit of byte 0 upward.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}

/// Return whether bit 'bitPos' of 'rawData' is set, using the same bit
/// numbering as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
556 
/// Get start position of actual data in `value`. Actual data is
/// stored in last `bitWidth`/CHAR_BIT bytes in big endian.
///
/// Returns a writable pointer into the raw word storage of `value`. The
/// const_cast exists so that writeAPInt can fill that storage in place. On
/// little-endian hosts the pointer is the start of the storage. On big-endian
/// hosts it is advanced past the leading bytes of the first 64-bit word that
/// do not hold significant data.
// NOTE(review): the big-endian offset `8 - divideCeil(bitWidth, CHAR_BIT)`
// assumes the value fits in one 64-bit word (bitWidth <= 64); for wider values
// it goes negative. Confirm wider element types are not used on big-endian
// hosts.
static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
  char *dataPos =
      const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big)
    dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
  return dataPos;
}

/// Read APInt `value` from appropriate position.
///
/// Despite the name, this copies the ceil(bitWidth / CHAR_BIT) significant
/// bytes *out of* `value` and *into* the raw buffer `outData`. It takes a
/// non-const reference only because getAPIntDataPos does.
static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
}

/// Write `inData` to appropriate position of APInt `value`.
///
/// Copies ceil(bitWidth / CHAR_BIT) bytes from `inData` directly into the
/// storage of `value`.
static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
}
579 
580 /// Writes value to the bit position `bitPos` in array `rawData`.
581 static void writeBits(char *rawData, size_t bitPos, APInt value) {
582   size_t bitWidth = value.getBitWidth();
583 
584   // If the bitwidth is 1 we just toggle the specific bit.
585   if (bitWidth == 1)
586     return setBit(rawData, bitPos, value.isOneValue());
587 
588   // Otherwise, the bit position is guaranteed to be byte aligned.
589   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
590   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
591 }
592 
593 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
594 /// `rawData`.
595 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
596   // Handle a boolean bit position.
597   if (bitWidth == 1)
598     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
599 
600   // Otherwise, the bit position must be 8-bit aligned.
601   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
602   APInt result(bitWidth, 0);
603   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
604   return result;
605 }
606 
607 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
608 /// same element count as 'type'.
609 template <typename Values>
610 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
611   return (values.size() == 1) ||
612          (type.getNumElements() == static_cast<int64_t>(values.size()));
613 }
614 
615 //===----------------------------------------------------------------------===//
// DenseElementsAttr Iterators
617 //===----------------------------------------------------------------------===//
618 
619 //===----------------------------------------------------------------------===//
620 // AttributeElementIterator
621 
622 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
623     DenseElementsAttr attr, size_t index)
624     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
625                                       Attribute, Attribute, Attribute>(
626           attr.getAsOpaquePointer(), index) {}
627 
628 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
629   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
630   Type eltTy = owner.getType().getElementType();
631   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
632     if (intEltTy.getWidth() == 1)
633       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
634                            owner.getContext());
635     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
636   }
637   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
638     IntElementIterator intIt(owner, index);
639     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
640     return FloatAttr::get(eltTy, *floatIt);
641   }
642   if (owner.isa<DenseStringElementsAttr>()) {
643     ArrayRef<StringRef> vals = owner.getRawStringData();
644     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
645   }
646   llvm_unreachable("unexpected element type");
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // BoolElementIterator
651 
652 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
653     DenseElementsAttr attr, size_t dataIndex)
654     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
655           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
656 
657 bool DenseElementsAttr::BoolElementIterator::operator*() const {
658   return getBit(getData(), getDataIndex());
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // IntElementIterator
663 
664 DenseElementsAttr::IntElementIterator::IntElementIterator(
665     DenseElementsAttr attr, size_t dataIndex)
666     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
667           attr.getRawData().data(), attr.isSplat(), dataIndex),
668       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
669 
670 APInt DenseElementsAttr::IntElementIterator::operator*() const {
671   return readBits(getData(),
672                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
673                   bitWidth);
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // ComplexIntElementIterator
678 
679 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
680     DenseElementsAttr attr, size_t dataIndex)
681     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
682                                       std::complex<APInt>, std::complex<APInt>,
683                                       std::complex<APInt>>(
684           attr.getRawData().data(), attr.isSplat(), dataIndex) {
685   auto complexType = attr.getType().getElementType().cast<ComplexType>();
686   bitWidth = getDenseElementBitWidth(complexType.getElementType());
687 }
688 
689 std::complex<APInt>
690 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
691   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
692   size_t offset = getDataIndex() * storageWidth * 2;
693   return {readBits(getData(), offset, bitWidth),
694           readBits(getData(), offset + storageWidth, bitWidth)};
695 }
696 
697 //===----------------------------------------------------------------------===//
698 // FloatElementIterator
699 
700 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
701     const llvm::fltSemantics &smt, IntElementIterator it)
702     : llvm::mapped_iterator<IntElementIterator,
703                             std::function<APFloat(const APInt &)>>(
704           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
705 
706 //===----------------------------------------------------------------------===//
707 // ComplexFloatElementIterator
708 
709 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
710     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
711     : llvm::mapped_iterator<
712           ComplexIntElementIterator,
713           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
714           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
715             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
716           }) {}
717 
718 //===----------------------------------------------------------------------===//
719 // DenseElementsAttr
720 //===----------------------------------------------------------------------===//
721 
722 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
723                                          ArrayRef<Attribute> values) {
724   assert(hasSameElementsOrSplat(type, values));
725 
726   // If the element type is not based on int/float/index, assume it is a string
727   // type.
728   auto eltType = type.getElementType();
729   if (!type.getElementType().isIntOrIndexOrFloat()) {
730     SmallVector<StringRef, 8> stringValues;
731     stringValues.reserve(values.size());
732     for (Attribute attr : values) {
733       assert(attr.isa<StringAttr>() &&
734              "expected string value for non integer/index/float element");
735       stringValues.push_back(attr.cast<StringAttr>().getValue());
736     }
737     return get(type, stringValues);
738   }
739 
740   // Otherwise, get the raw storage width to use for the allocation.
741   size_t bitWidth = getDenseElementBitWidth(eltType);
742   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
743 
744   // Compress the attribute values into a character buffer.
745   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
746                             values.size());
747   APInt intVal;
748   for (unsigned i = 0, e = values.size(); i < e; ++i) {
749     assert(eltType == values[i].getType() &&
750            "expected attribute value to have element type");
751 
752     switch (eltType.getKind()) {
753     case StandardTypes::BF16:
754     case StandardTypes::F16:
755     case StandardTypes::F32:
756     case StandardTypes::F64:
757       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
758       break;
759     case StandardTypes::Integer:
760     case StandardTypes::Index:
761       intVal = values[i].isa<BoolAttr>()
762                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
763                    : values[i].cast<IntegerAttr>().getValue();
764       break;
765     default:
766       llvm_unreachable("unexpected element type");
767     }
768     assert(intVal.getBitWidth() == bitWidth &&
769            "expected value to have same bitwidth as element type");
770     writeBits(data.data(), i * storageBitWidth, intVal);
771   }
772   return DenseIntOrFPElementsAttr::getRaw(type, data,
773                                           /*isSplat=*/(values.size() == 1));
774 }
775 
776 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
777                                          ArrayRef<bool> values) {
778   assert(hasSameElementsOrSplat(type, values));
779   assert(type.getElementType().isInteger(1));
780 
781   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
782   for (int i = 0, e = values.size(); i != e; ++i)
783     setBit(buff.data(), i, values[i]);
784   return DenseIntOrFPElementsAttr::getRaw(type, buff,
785                                           /*isSplat=*/(values.size() == 1));
786 }
787 
788 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
789                                          ArrayRef<StringRef> values) {
790   assert(!type.getElementType().isIntOrFloat());
791   return DenseStringElementsAttr::get(type, values);
792 }
793 
794 /// Constructs a dense integer elements attribute from an array of APInt
795 /// values. Each APInt value is expected to have the same bitwidth as the
796 /// element type of 'type'.
797 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
798                                          ArrayRef<APInt> values) {
799   assert(type.getElementType().isIntOrIndex());
800   assert(hasSameElementsOrSplat(type, values));
801   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
802   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
803                                           /*isSplat=*/(values.size() == 1));
804 }
805 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
806                                          ArrayRef<std::complex<APInt>> values) {
807   ComplexType complex = type.getElementType().cast<ComplexType>();
808   assert(complex.getElementType().isa<IntegerType>());
809   assert(hasSameElementsOrSplat(type, values));
810   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
811   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
812                           values.size() * 2);
813   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
814                                           /*isSplat=*/(values.size() == 1));
815 }
816 
817 // Constructs a dense float elements attribute from an array of APFloat
818 // values. Each APFloat value is expected to have the same bitwidth as the
819 // element type of 'type'.
820 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
821                                          ArrayRef<APFloat> values) {
822   assert(type.getElementType().isa<FloatType>());
823   assert(hasSameElementsOrSplat(type, values));
824   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
825   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
826                                           /*isSplat=*/(values.size() == 1));
827 }
828 DenseElementsAttr
829 DenseElementsAttr::get(ShapedType type,
830                        ArrayRef<std::complex<APFloat>> values) {
831   ComplexType complex = type.getElementType().cast<ComplexType>();
832   assert(complex.getElementType().isa<FloatType>());
833   assert(hasSameElementsOrSplat(type, values));
834   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
835                            values.size() * 2);
836   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
837   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
838                                           /*isSplat=*/(values.size() == 1));
839 }
840 
841 /// Construct a dense elements attribute from a raw buffer representing the
842 /// data for this attribute. Users should generally not use this methods as
843 /// the expected buffer format may not be a form the user expects.
844 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
845                                                       ArrayRef<char> rawBuffer,
846                                                       bool isSplatBuffer) {
847   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
848 }
849 
850 /// Returns true if the given buffer is a valid raw buffer for the given type.
851 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
852                                          ArrayRef<char> rawBuffer,
853                                          bool &detectedSplat) {
854   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
855   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
856 
857   // Storage width of 1 is special as it is packed by the bit.
858   if (storageWidth == 1) {
859     // Check for a splat, or a buffer equal to the number of elements.
860     if ((detectedSplat = rawBuffer.size() == 1))
861       return true;
862     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
863   }
864   // All other types are 8-bit aligned.
865   if ((detectedSplat = rawBufferWidth == storageWidth))
866     return true;
867   return rawBufferWidth == (storageWidth * type.getNumElements());
868 }
869 
870 /// Check the information for a C++ data type, check if this type is valid for
871 /// the current attribute. This method is used to verify specific type
872 /// invariants that the templatized 'getValues' method cannot.
873 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
874                               bool isSigned) {
875   // Make sure that the data element size is the same as the type element width.
876   if (getDenseElementBitWidth(type) !=
877       static_cast<size_t>(dataEltSize * CHAR_BIT))
878     return false;
879 
880   // Check that the element type is either float or integer or index.
881   if (!isInt)
882     return type.isa<FloatType>();
883   if (type.isIndex())
884     return true;
885 
886   auto intType = type.dyn_cast<IntegerType>();
887   if (!intType)
888     return false;
889 
890   // Make sure signedness semantics is consistent.
891   if (intType.isSignless())
892     return true;
893   return intType.isSigned() ? isSigned : !isSigned;
894 }
895 
896 /// Defaults down the subclass implementation.
897 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
898                                                    ArrayRef<char> data,
899                                                    int64_t dataEltSize,
900                                                    bool isInt, bool isSigned) {
901   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
902                                                  isSigned);
903 }
904 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
905                                                       ArrayRef<char> data,
906                                                       int64_t dataEltSize,
907                                                       bool isInt,
908                                                       bool isSigned) {
909   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
910                                                     isInt, isSigned);
911 }
912 
913 /// A method used to verify specific type invariants that the templatized 'get'
914 /// method cannot.
915 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
916                                           bool isSigned) const {
917   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
918                              isSigned);
919 }
920 
921 /// Check the information for a C++ data type, check if this type is valid for
922 /// the current attribute.
923 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
924                                        bool isSigned) const {
925   return ::isValidIntOrFloat(
926       getType().getElementType().cast<ComplexType>().getElementType(),
927       dataEltSize / 2, isInt, isSigned);
928 }
929 
930 /// Returns if this attribute corresponds to a splat, i.e. if all element
931 /// values are the same.
932 bool DenseElementsAttr::isSplat() const {
933   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
934 }
935 
936 /// Return the held element values as a range of Attributes.
937 auto DenseElementsAttr::getAttributeValues() const
938     -> llvm::iterator_range<AttributeElementIterator> {
939   return {attr_value_begin(), attr_value_end()};
940 }
941 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
942   return AttributeElementIterator(*this, 0);
943 }
944 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
945   return AttributeElementIterator(*this, getNumElements());
946 }
947 
948 /// Return the held element values as a range of bool. The element type of
949 /// this attribute must be of integer type of bitwidth 1.
950 auto DenseElementsAttr::getBoolValues() const
951     -> llvm::iterator_range<BoolElementIterator> {
952   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
953   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
954   (void)eltType;
955   return {BoolElementIterator(*this, 0),
956           BoolElementIterator(*this, getNumElements())};
957 }
958 
959 /// Return the held element values as a range of APInts. The element type of
960 /// this attribute must be of integer type.
961 auto DenseElementsAttr::getIntValues() const
962     -> llvm::iterator_range<IntElementIterator> {
963   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
964   return {raw_int_begin(), raw_int_end()};
965 }
966 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
967   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
968   return raw_int_begin();
969 }
970 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
971   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
972   return raw_int_end();
973 }
974 auto DenseElementsAttr::getComplexIntValues() const
975     -> llvm::iterator_range<ComplexIntElementIterator> {
976   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
977   (void)eltTy;
978   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
979   return {ComplexIntElementIterator(*this, 0),
980           ComplexIntElementIterator(*this, getNumElements())};
981 }
982 
983 /// Return the held element values as a range of APFloat. The element type of
984 /// this attribute must be of float type.
985 auto DenseElementsAttr::getFloatValues() const
986     -> llvm::iterator_range<FloatElementIterator> {
987   auto elementType = getType().getElementType().cast<FloatType>();
988   const auto &elementSemantics = elementType.getFloatSemantics();
989   return {FloatElementIterator(elementSemantics, raw_int_begin()),
990           FloatElementIterator(elementSemantics, raw_int_end())};
991 }
992 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
993   return getFloatValues().begin();
994 }
995 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
996   return getFloatValues().end();
997 }
998 auto DenseElementsAttr::getComplexFloatValues() const
999     -> llvm::iterator_range<ComplexFloatElementIterator> {
1000   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1001   assert(eltTy.isa<FloatType>() && "expected complex float type");
1002   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1003   return {{semantics, {*this, 0}},
1004           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1005 }
1006 
1007 /// Return the raw storage data held by this attribute.
1008 ArrayRef<char> DenseElementsAttr::getRawData() const {
1009   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1010 }
1011 
1012 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1013   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1014 }
1015 
1016 /// Return a new DenseElementsAttr that has the same data as the current
1017 /// attribute, but has been reshaped to 'newType'. The new type must have the
1018 /// same total number of elements as well as element type.
1019 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1020   ShapedType curType = getType();
1021   if (curType == newType)
1022     return *this;
1023 
1024   (void)curType;
1025   assert(newType.getElementType() == curType.getElementType() &&
1026          "expected the same element type");
1027   assert(newType.getNumElements() == curType.getNumElements() &&
1028          "expected the same number of elements");
1029   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1030 }
1031 
1032 DenseElementsAttr
1033 DenseElementsAttr::mapValues(Type newElementType,
1034                              function_ref<APInt(const APInt &)> mapping) const {
1035   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1036 }
1037 
1038 DenseElementsAttr DenseElementsAttr::mapValues(
1039     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1040   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // DenseStringElementsAttr
1045 //===----------------------------------------------------------------------===//
1046 
1047 DenseStringElementsAttr
1048 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1049   return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
1050                    type, values, (values.size() == 1));
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // DenseIntOrFPElementsAttr
1055 //===----------------------------------------------------------------------===//
1056 
1057 /// Utility method to write a range of APInt values to a buffer.
1058 template <typename APRangeT>
1059 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1060                                 APRangeT &&values) {
1061   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1062   size_t offset = 0;
1063   for (auto it = values.begin(), e = values.end(); it != e;
1064        ++it, offset += storageWidth) {
1065     assert((*it).getBitWidth() <= storageWidth);
1066     writeBits(data.data(), offset, *it);
1067   }
1068 }
1069 
1070 /// Constructs a dense elements attribute from an array of raw APFloat values.
1071 /// Each APFloat value is expected to have the same bitwidth as the element
1072 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1073 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1074                                                    size_t storageWidth,
1075                                                    ArrayRef<APFloat> values,
1076                                                    bool isSplat) {
1077   std::vector<char> data;
1078   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1079   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1080   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1081 }
1082 
1083 /// Constructs a dense elements attribute from an array of raw APInt values.
1084 /// Each APInt value is expected to have the same bitwidth as the element type
1085 /// of 'type'.
1086 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1087                                                    size_t storageWidth,
1088                                                    ArrayRef<APInt> values,
1089                                                    bool isSplat) {
1090   std::vector<char> data;
1091   writeAPIntsToBuffer(storageWidth, data, values);
1092   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1093 }
1094 
1095 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1096                                                    ArrayRef<char> data,
1097                                                    bool isSplat) {
1098   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1099          "type must be ranked tensor or vector");
1100   assert(type.hasStaticShape() && "type must have static shape");
1101   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
1102                    type, data, isSplat);
1103 }
1104 
1105 /// Overload of the raw 'get' method that asserts that the given type is of
1106 /// complex type. This method is used to verify type invariants that the
1107 /// templatized 'get' method cannot.
1108 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1109                                                           ArrayRef<char> data,
1110                                                           int64_t dataEltSize,
1111                                                           bool isInt,
1112                                                           bool isSigned) {
1113   assert(::isValidIntOrFloat(
1114       type.getElementType().cast<ComplexType>().getElementType(),
1115       dataEltSize / 2, isInt, isSigned));
1116 
1117   int64_t numElements = data.size() / dataEltSize;
1118   assert(numElements == 1 || numElements == type.getNumElements());
1119   return getRaw(type, data, /*isSplat=*/numElements == 1);
1120 }
1121 
1122 /// Overload of the 'getRaw' method that asserts that the given type is of
1123 /// integer type. This method is used to verify type invariants that the
1124 /// templatized 'get' method cannot.
1125 DenseElementsAttr
1126 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1127                                            int64_t dataEltSize, bool isInt,
1128                                            bool isSigned) {
1129   assert(
1130       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1131 
1132   int64_t numElements = data.size() / dataEltSize;
1133   assert(numElements == 1 || numElements == type.getNumElements());
1134   return getRaw(type, data, /*isSplat=*/numElements == 1);
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // DenseFPElementsAttr
1139 //===----------------------------------------------------------------------===//
1140 
1141 template <typename Fn, typename Attr>
1142 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1143                                 Type newElementType,
1144                                 llvm::SmallVectorImpl<char> &data) {
1145   size_t bitWidth = getDenseElementBitWidth(newElementType);
1146   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1147 
1148   ShapedType newArrayType;
1149   if (inType.isa<RankedTensorType>())
1150     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1151   else if (inType.isa<UnrankedTensorType>())
1152     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1153   else if (inType.isa<VectorType>())
1154     newArrayType = VectorType::get(inType.getShape(), newElementType);
1155   else
1156     assert(newArrayType && "Unhandled tensor type");
1157 
1158   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1159   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1160 
1161   // Functor used to process a single element value of the attribute.
1162   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1163     auto newInt = mapping(value);
1164     assert(newInt.getBitWidth() == bitWidth);
1165     writeBits(data.data(), index * storageBitWidth, newInt);
1166   };
1167 
1168   // Check for the splat case.
1169   if (attr.isSplat()) {
1170     processElt(*attr.begin(), /*index=*/0);
1171     return newArrayType;
1172   }
1173 
1174   // Otherwise, process all of the element values.
1175   uint64_t elementIdx = 0;
1176   for (auto value : attr)
1177     processElt(value, elementIdx++);
1178   return newArrayType;
1179 }
1180 
1181 DenseElementsAttr DenseFPElementsAttr::mapValues(
1182     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1183   llvm::SmallVector<char, 8> elementData;
1184   auto newArrayType =
1185       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1186 
1187   return getRaw(newArrayType, elementData, isSplat());
1188 }
1189 
1190 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1191 bool DenseFPElementsAttr::classof(Attribute attr) {
1192   return attr.isa<DenseElementsAttr>() &&
1193          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1194 }
1195 
1196 //===----------------------------------------------------------------------===//
1197 // DenseIntElementsAttr
1198 //===----------------------------------------------------------------------===//
1199 
1200 DenseElementsAttr DenseIntElementsAttr::mapValues(
1201     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1202   llvm::SmallVector<char, 8> elementData;
1203   auto newArrayType =
1204       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1205 
1206   return getRaw(newArrayType, elementData, isSplat());
1207 }
1208 
1209 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1210 bool DenseIntElementsAttr::classof(Attribute attr) {
1211   return attr.isa<DenseElementsAttr>() &&
1212          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // OpaqueElementsAttr
1217 //===----------------------------------------------------------------------===//
1218 
1219 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1220                                            StringRef bytes) {
1221   assert(TensorType::isValidElementType(type.getElementType()) &&
1222          "Input element type should be a valid tensor element type");
1223   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
1224                    dialect, bytes);
1225 }
1226 
1227 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1228 
1229 /// Return the value at the given index. If index does not refer to a valid
1230 /// element, then a null attribute is returned.
1231 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1232   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1233   if (Dialect *dialect = getDialect())
1234     return dialect->extractElementHook(*this, index);
1235   return Attribute();
1236 }
1237 
1238 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1239 
1240 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1241   if (auto *d = getDialect())
1242     return d->decodeHook(*this, result);
1243   return true;
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // SparseElementsAttr
1248 //===----------------------------------------------------------------------===//
1249 
1250 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1251                                            DenseElementsAttr indices,
1252                                            DenseElementsAttr values) {
1253   assert(indices.getType().getElementType().isInteger(64) &&
1254          "expected sparse indices to be 64-bit integer values");
1255   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
1256          "type must be ranked tensor or vector");
1257   assert(type.hasStaticShape() && "type must have static shape");
1258   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
1259                    indices.cast<DenseIntElementsAttr>(), values);
1260 }
1261 
1262 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1263   return getImpl()->indices;
1264 }
1265 
1266 DenseElementsAttr SparseElementsAttr::getValues() const {
1267   return getImpl()->values;
1268 }
1269 
1270 /// Return the value of the element at the given index.
1271 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1272   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1273   auto type = getType();
1274 
1275   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1276   // as a 1-D index array.
1277   auto sparseIndices = getIndices();
1278   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1279 
1280   // Check to see if the indices are a splat.
1281   if (sparseIndices.isSplat()) {
1282     // If the index is also not a splat of the index value, we know that the
1283     // value is zero.
1284     auto splatIndex = *sparseIndexValues.begin();
1285     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1286       return getZeroAttr();
1287 
1288     // If the indices are a splat, we also expect the values to be a splat.
1289     assert(getValues().isSplat() && "expected splat values");
1290     return getValues().getSplatValue();
1291   }
1292 
1293   // Build a mapping between known indices and the offset of the stored element.
1294   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1295   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1296   size_t rank = type.getRank();
1297   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1298     mappedIndices.try_emplace(
1299         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1300 
1301   // Look for the provided index key within the mapped indices. If the provided
1302   // index is not found, then return a zero attribute.
1303   auto it = mappedIndices.find(index);
1304   if (it == mappedIndices.end())
1305     return getZeroAttr();
1306 
1307   // Otherwise, return the held sparse value element.
1308   return getValues().getValue(it->second);
1309 }
1310 
1311 /// Get a zero APFloat for the given sparse attribute.
1312 APFloat SparseElementsAttr::getZeroAPFloat() const {
1313   auto eltType = getType().getElementType().cast<FloatType>();
1314   return APFloat(eltType.getFloatSemantics());
1315 }
1316 
1317 /// Get a zero APInt for the given sparse attribute.
1318 APInt SparseElementsAttr::getZeroAPInt() const {
1319   auto eltType = getType().getElementType().cast<IntegerType>();
1320   return APInt::getNullValue(eltType.getWidth());
1321 }
1322 
1323 /// Get a zero attribute for the given attribute type.
1324 Attribute SparseElementsAttr::getZeroAttr() const {
1325   auto eltType = getType().getElementType();
1326 
1327   // Handle floating point elements.
1328   if (eltType.isa<FloatType>())
1329     return FloatAttr::get(eltType, 0);
1330 
1331   // Otherwise, this is an integer.
1332   auto intEltTy = eltType.cast<IntegerType>();
1333   if (intEltTy.getWidth() == 1)
1334     return BoolAttr::get(false, eltType.getContext());
1335   return IntegerAttr::get(eltType, 0);
1336 }
1337 
1338 /// Flatten, and return, all of the sparse indices in this attribute in
1339 /// row-major order.
1340 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1341   std::vector<ptrdiff_t> flatSparseIndices;
1342 
1343   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1344   // as a 1-D index array.
1345   auto sparseIndices = getIndices();
1346   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1347   if (sparseIndices.isSplat()) {
1348     SmallVector<uint64_t, 8> indices(getType().getRank(),
1349                                      *sparseIndexValues.begin());
1350     flatSparseIndices.push_back(getFlattenedIndex(indices));
1351     return flatSparseIndices;
1352   }
1353 
1354   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1355   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1356   size_t rank = getType().getRank();
1357   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1358     flatSparseIndices.push_back(getFlattenedIndex(
1359         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1360   return flatSparseIndices;
1361 }
1362 
1363 //===----------------------------------------------------------------------===//
1364 // MutableDictionaryAttr
1365 //===----------------------------------------------------------------------===//
1366 
1367 MutableDictionaryAttr::MutableDictionaryAttr(
1368     ArrayRef<NamedAttribute> attributes) {
1369   setAttrs(attributes);
1370 }
1371 
1372 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1373   return attrs ? attrs.getValue() : llvm::None;
1374 }
1375 
1376 /// Replace the held attributes with ones provided in 'newAttrs'.
1377 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1378   // Don't create an attribute list if there are no attributes.
1379   if (attributes.empty())
1380     attrs = nullptr;
1381   else
1382     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1383 }
1384 
1385 /// Return the specified attribute if present, null otherwise.
1386 Attribute MutableDictionaryAttr::get(StringRef name) const {
1387   return attrs ? attrs.get(name) : nullptr;
1388 }
1389 
1390 /// Return the specified attribute if present, null otherwise.
1391 Attribute MutableDictionaryAttr::get(Identifier name) const {
1392   return attrs ? attrs.get(name) : nullptr;
1393 }
1394 
1395 /// Return the specified named attribute if present, None otherwise.
1396 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1397   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1398 }
1399 Optional<NamedAttribute>
1400 MutableDictionaryAttr::getNamed(Identifier name) const {
1401   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1402 }
1403 
1404 /// If the an attribute exists with the specified name, change it to the new
1405 /// value.  Otherwise, add a new attribute with the specified name/value.
1406 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1407   assert(value && "attributes may never be null");
1408 
1409   // Look for an existing value for the given name, and set it in-place.
1410   ArrayRef<NamedAttribute> values = getAttrs();
1411   auto it = llvm::find_if(
1412       values, [name](NamedAttribute attr) { return attr.first == name; });
1413   if (it != values.end()) {
1414     // Bail out early if the value is the same as what we already have.
1415     if (it->second == value)
1416       return;
1417 
1418     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1419     newAttrs[it - values.begin()].second = value;
1420     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1421     return;
1422   }
1423 
1424   // Otherwise, insert the new attribute into its sorted position.
1425   it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
1426   SmallVector<NamedAttribute, 8> newAttrs;
1427   newAttrs.reserve(values.size() + 1);
1428   newAttrs.append(values.begin(), it);
1429   newAttrs.push_back({name, value});
1430   newAttrs.append(it, values.end());
1431   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1432 }
1433 
1434 /// Remove the attribute with the specified name if it exists.  The return
1435 /// value indicates whether the attribute was present or not.
1436 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1437   auto origAttrs = getAttrs();
1438   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1439     if (origAttrs[i].first == name) {
1440       // Handle the simple case of removing the only attribute in the list.
1441       if (e == 1) {
1442         attrs = nullptr;
1443         return RemoveResult::Removed;
1444       }
1445 
1446       SmallVector<NamedAttribute, 8> newAttrs;
1447       newAttrs.reserve(origAttrs.size() - 1);
1448       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1449       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1450       attrs = DictionaryAttr::getWithSorted(newAttrs,
1451                                             newAttrs[0].second.getContext());
1452       return RemoveResult::Removed;
1453     }
1454   }
1455   return RemoveResult::NotFound;
1456 }
1457