1 //===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 //===----------------------------------------------------------------------===//
26 // AttributeStorage
27 //===----------------------------------------------------------------------===//
28 
29 AttributeStorage::AttributeStorage(Type type)
30     : type(type.getAsOpaquePointer()) {}
31 AttributeStorage::AttributeStorage() : type(nullptr) {}
32 
33 Type AttributeStorage::getType() const {
34   return Type::getFromOpaquePointer(type);
35 }
36 void AttributeStorage::setType(Type newType) {
37   type = newType.getAsOpaquePointer();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // Attribute
42 //===----------------------------------------------------------------------===//
43 
44 /// Return the type of this attribute.
45 Type Attribute::getType() const { return impl->getType(); }
46 
47 /// Return the context this attribute belongs to.
48 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
49 
50 /// Get the dialect this attribute is registered to.
51 Dialect &Attribute::getDialect() const {
52   return impl->getAbstractAttribute().getDialect();
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // AffineMapAttr
57 //===----------------------------------------------------------------------===//
58 
59 AffineMapAttr AffineMapAttr::get(AffineMap value) {
60   return Base::get(value.getContext(), value);
61 }
62 
63 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
64 
65 //===----------------------------------------------------------------------===//
66 // ArrayAttr
67 //===----------------------------------------------------------------------===//
68 
69 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
70   return Base::get(context, value);
71 }
72 
73 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
74 
75 Attribute ArrayAttr::operator[](unsigned idx) const {
76   assert(idx < size() && "index out of bounds");
77   return getValue()[idx];
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // DictionaryAttr
82 //===----------------------------------------------------------------------===//
83 
84 /// Helper function that does either an in place sort or sorts from source array
85 /// into destination. If inPlace then storage is both the source and the
86 /// destination, else value is the source and storage destination. Returns
87 /// whether source was sorted.
88 template <bool inPlace>
89 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
90                                SmallVectorImpl<NamedAttribute> &storage) {
91   // Specialize for the common case.
92   switch (value.size()) {
93   case 0:
94     // Zero already sorted.
95     break;
96   case 1:
97     // One already sorted but may need to be copied.
98     if (!inPlace)
99       storage.assign({value[0]});
100     break;
101   case 2: {
102     bool isSorted = value[0] < value[1];
103     if (inPlace) {
104       if (!isSorted)
105         std::swap(storage[0], storage[1]);
106     } else if (isSorted) {
107       storage.assign({value[0], value[1]});
108     } else {
109       storage.assign({value[1], value[0]});
110     }
111     return !isSorted;
112   }
113   default:
114     if (!inPlace)
115       storage.assign(value.begin(), value.end());
116     // Check to see they are sorted already.
117     bool isSorted = llvm::is_sorted(value);
118     if (!isSorted) {
119       // If not, do a general sort.
120       llvm::array_pod_sort(storage.begin(), storage.end());
121       value = storage;
122     }
123     return !isSorted;
124   }
125   return false;
126 }
127 
128 /// Returns an entry with a duplicate name from the given sorted array of named
129 /// attributes. Returns llvm::None if all elements have unique names.
130 static Optional<NamedAttribute>
131 findDuplicateElement(ArrayRef<NamedAttribute> value) {
132   const Optional<NamedAttribute> none{llvm::None};
133   if (value.size() < 2)
134     return none;
135 
136   if (value.size() == 2)
137     return value[0].first == value[1].first ? value[0] : none;
138 
139   auto it = std::adjacent_find(
140       value.begin(), value.end(),
141       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
142   return it != value.end() ? *it : none;
143 }
144 
145 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
146                           SmallVectorImpl<NamedAttribute> &storage) {
147   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
148   assert(!findDuplicateElement(storage) &&
149          "DictionaryAttr element names must be unique");
150   return isSorted;
151 }
152 
153 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
154   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
155   assert(!findDuplicateElement(array) &&
156          "DictionaryAttr element names must be unique");
157   return isSorted;
158 }
159 
160 Optional<NamedAttribute>
161 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
162                               bool isSorted) {
163   if (!isSorted)
164     dictionaryAttrSort</*inPlace=*/true>(array, array);
165   return findDuplicateElement(array);
166 }
167 
168 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
169                                    MLIRContext *context) {
170   if (value.empty())
171     return DictionaryAttr::getEmpty(context);
172   assert(llvm::all_of(value,
173                       [](const NamedAttribute &attr) { return attr.second; }) &&
174          "value cannot have null entries");
175 
176   // We need to sort the element list to canonicalize it.
177   SmallVector<NamedAttribute, 8> storage;
178   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
179     value = storage;
180   assert(!findDuplicateElement(value) &&
181          "DictionaryAttr element names must be unique");
182   return Base::get(context, value);
183 }
184 /// Construct a dictionary with an array of values that is known to already be
185 /// sorted by name and uniqued.
186 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
187                                              MLIRContext *context) {
188   if (value.empty())
189     return DictionaryAttr::getEmpty(context);
190   // Ensure that the attribute elements are unique and sorted.
191   assert(llvm::is_sorted(value,
192                          [](NamedAttribute l, NamedAttribute r) {
193                            return l.first.strref() < r.first.strref();
194                          }) &&
195          "expected attribute values to be sorted");
196   assert(!findDuplicateElement(value) &&
197          "DictionaryAttr element names must be unique");
198   return Base::get(context, value);
199 }
200 
201 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
202   return getImpl()->getElements();
203 }
204 
205 /// Return the specified attribute if present, null otherwise.
206 Attribute DictionaryAttr::get(StringRef name) const {
207   Optional<NamedAttribute> attr = getNamed(name);
208   return attr ? attr->second : nullptr;
209 }
210 Attribute DictionaryAttr::get(Identifier name) const {
211   Optional<NamedAttribute> attr = getNamed(name);
212   return attr ? attr->second : nullptr;
213 }
214 
215 /// Return the specified named attribute if present, None otherwise.
216 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
217   ArrayRef<NamedAttribute> values = getValue();
218   const auto *it = llvm::lower_bound(values, name);
219   return it != values.end() && it->first == name ? *it
220                                                  : Optional<NamedAttribute>();
221 }
222 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
223   for (auto elt : getValue())
224     if (elt.first == name)
225       return elt;
226   return llvm::None;
227 }
228 
229 DictionaryAttr::iterator DictionaryAttr::begin() const {
230   return getValue().begin();
231 }
232 DictionaryAttr::iterator DictionaryAttr::end() const {
233   return getValue().end();
234 }
235 size_t DictionaryAttr::size() const { return getValue().size(); }
236 
237 //===----------------------------------------------------------------------===//
238 // FloatAttr
239 //===----------------------------------------------------------------------===//
240 
241 FloatAttr FloatAttr::get(Type type, double value) {
242   return Base::get(type.getContext(), type, value);
243 }
244 
245 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
246   return Base::getChecked(loc, type, value);
247 }
248 
249 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
250   return Base::get(type.getContext(), type, value);
251 }
252 
253 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
254   return Base::getChecked(loc, type, value);
255 }
256 
257 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
258 
259 double FloatAttr::getValueAsDouble() const {
260   return getValueAsDouble(getValue());
261 }
262 double FloatAttr::getValueAsDouble(APFloat value) {
263   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
264     bool losesInfo = false;
265     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
266                   &losesInfo);
267   }
268   return value.convertToDouble();
269 }
270 
271 /// Verify construction invariants.
272 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
273   if (!type.isa<FloatType>())
274     return emitError(loc, "expected floating point type");
275   return success();
276 }
277 
278 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
279                                                       double value) {
280   return verifyFloatTypeInvariants(loc, type);
281 }
282 
283 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
284                                                       const APFloat &value) {
285   // Verify that the type is correct.
286   if (failed(verifyFloatTypeInvariants(loc, type)))
287     return failure();
288 
289   // Verify that the type semantics match that of the value.
290   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
291     return emitError(
292         loc, "FloatAttr type doesn't match the type implied by its value");
293   }
294   return success();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // SymbolRefAttr
299 //===----------------------------------------------------------------------===//
300 
301 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
302   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
303 }
304 
305 SymbolRefAttr SymbolRefAttr::get(StringRef value,
306                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
307                                  MLIRContext *ctx) {
308   return Base::get(ctx, value, nestedReferences);
309 }
310 
311 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
312 
313 StringRef SymbolRefAttr::getLeafReference() const {
314   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
315   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
316 }
317 
318 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
319   return getImpl()->getNestedRefs();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // IntegerAttr
324 //===----------------------------------------------------------------------===//
325 
326 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
327   if (type.isSignlessInteger(1))
328     return BoolAttr::get(value.getBoolValue(), type.getContext());
329   return Base::get(type.getContext(), type, value);
330 }
331 
332 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
333   // This uses 64 bit APInts by default for index type.
334   if (type.isIndex())
335     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
336 
337   auto intType = type.cast<IntegerType>();
338   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
339 }
340 
341 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
342 
343 int64_t IntegerAttr::getInt() const {
344   assert((getImpl()->getType().isIndex() ||
345           getImpl()->getType().isSignlessInteger()) &&
346          "must be signless integer");
347   return getValue().getSExtValue();
348 }
349 
350 int64_t IntegerAttr::getSInt() const {
351   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
352   return getValue().getSExtValue();
353 }
354 
355 uint64_t IntegerAttr::getUInt() const {
356   assert(getImpl()->getType().isUnsignedInteger() &&
357          "must be unsigned integer");
358   return getValue().getZExtValue();
359 }
360 
361 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
362   if (type.isa<IntegerType, IndexType>())
363     return success();
364   return emitError(loc, "expected integer or index type");
365 }
366 
367 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
368                                                         int64_t value) {
369   return verifyIntegerTypeInvariants(loc, type);
370 }
371 
372 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
373                                                         const APInt &value) {
374   if (failed(verifyIntegerTypeInvariants(loc, type)))
375     return failure();
376   if (auto integerType = type.dyn_cast<IntegerType>())
377     if (integerType.getWidth() != value.getBitWidth())
378       return emitError(loc, "integer type bit width (")
379              << integerType.getWidth() << ") doesn't match value bit width ("
380              << value.getBitWidth() << ")";
381   return success();
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // BoolAttr
386 
387 bool BoolAttr::getValue() const {
388   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
389   return storage->getValue().getBoolValue();
390 }
391 
392 bool BoolAttr::classof(Attribute attr) {
393   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
394   return intAttr && intAttr.getType().isSignlessInteger(1);
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // IntegerSetAttr
399 //===----------------------------------------------------------------------===//
400 
401 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
402   return Base::get(value.getConstraint(0).getContext(), value);
403 }
404 
405 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
406 
407 //===----------------------------------------------------------------------===//
408 // OpaqueAttr
409 //===----------------------------------------------------------------------===//
410 
411 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
412                            MLIRContext *context) {
413   return Base::get(context, dialect, attrData, type);
414 }
415 
416 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
417                                   Type type, Location location) {
418   return Base::getChecked(location, dialect, attrData, type);
419 }
420 
421 /// Returns the dialect namespace of the opaque attribute.
422 Identifier OpaqueAttr::getDialectNamespace() const {
423   return getImpl()->dialectNamespace;
424 }
425 
426 /// Returns the raw attribute data of the opaque attribute.
427 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
428 
429 /// Verify the construction of an opaque attribute.
430 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
431                                                        Identifier dialect,
432                                                        StringRef attrData,
433                                                        Type type) {
434   if (!Dialect::isValidNamespace(dialect.strref()))
435     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
436   return success();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // StringAttr
441 //===----------------------------------------------------------------------===//
442 
443 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
444   return get(bytes, NoneType::get(context));
445 }
446 
447 /// Get an instance of a StringAttr with the given string and Type.
448 StringAttr StringAttr::get(StringRef bytes, Type type) {
449   return Base::get(type.getContext(), bytes, type);
450 }
451 
452 StringRef StringAttr::getValue() const { return getImpl()->value; }
453 
454 //===----------------------------------------------------------------------===//
455 // TypeAttr
456 //===----------------------------------------------------------------------===//
457 
458 TypeAttr TypeAttr::get(Type value) {
459   return Base::get(value.getContext(), value);
460 }
461 
462 Type TypeAttr::getValue() const { return getImpl()->value; }
463 
464 //===----------------------------------------------------------------------===//
465 // ElementsAttr
466 //===----------------------------------------------------------------------===//
467 
468 ShapedType ElementsAttr::getType() const {
469   return Attribute::getType().cast<ShapedType>();
470 }
471 
472 /// Returns the number of elements held by this attribute.
473 int64_t ElementsAttr::getNumElements() const {
474   return getType().getNumElements();
475 }
476 
477 /// Return the value at the given index. If index does not refer to a valid
478 /// element, then a null attribute is returned.
479 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
480   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
481     return denseAttr.getValue(index);
482   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
483     return opaqueAttr.getValue(index);
484   return cast<SparseElementsAttr>().getValue(index);
485 }
486 
487 /// Return if the given 'index' refers to a valid element in this attribute.
488 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
489   auto type = getType();
490 
491   // Verify that the rank of the indices matches the held type.
492   auto rank = type.getRank();
493   if (rank != static_cast<int64_t>(index.size()))
494     return false;
495 
496   // Verify that all of the indices are within the shape dimensions.
497   auto shape = type.getShape();
498   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
499     return static_cast<int64_t>(index[i]) < shape[i];
500   });
501 }
502 
503 ElementsAttr
504 ElementsAttr::mapValues(Type newElementType,
505                         function_ref<APInt(const APInt &)> mapping) const {
506   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
507     return intOrFpAttr.mapValues(newElementType, mapping);
508   llvm_unreachable("unsupported ElementsAttr subtype");
509 }
510 
511 ElementsAttr
512 ElementsAttr::mapValues(Type newElementType,
513                         function_ref<APInt(const APFloat &)> mapping) const {
514   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
515     return intOrFpAttr.mapValues(newElementType, mapping);
516   llvm_unreachable("unsupported ElementsAttr subtype");
517 }
518 
519 /// Method for support type inquiry through isa, cast and dyn_cast.
520 bool ElementsAttr::classof(Attribute attr) {
521   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
522                   OpaqueElementsAttr, SparseElementsAttr>();
523 }
524 
525 /// Returns the 1 dimensional flattened row-major index from the given
526 /// multi-dimensional index.
527 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
528   assert(isValidIndex(index) && "expected valid multi-dimensional index");
529   auto type = getType();
530 
531   // Reduce the provided multidimensional index into a flattended 1D row-major
532   // index.
533   auto rank = type.getRank();
534   auto shape = type.getShape();
535   uint64_t valueIndex = 0;
536   uint64_t dimMultiplier = 1;
537   for (int i = rank - 1; i >= 0; --i) {
538     valueIndex += index[i] * dimMultiplier;
539     dimMultiplier *= shape[i];
540   }
541   return valueIndex;
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // DenseElementsAttr Utilities
546 //===----------------------------------------------------------------------===//
547 
548 /// Get the bitwidth of a dense element type within the buffer.
549 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
550 static size_t getDenseElementStorageWidth(size_t origWidth) {
551   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
552 }
553 static size_t getDenseElementStorageWidth(Type elementType) {
554   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
555 }
556 
/// Set bit `bitPos` of the byte buffer `rawData` to `value`. Bits are numbered
/// from the least significant bit of byte 0 upwards.
static void setBit(char *rawData, size_t bitPos, bool value) {
  const unsigned mask = 1u << (bitPos % CHAR_BIT);
  char &byte = rawData[bitPos / CHAR_BIT];
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}
564 
/// Return the value of bit `bitPos` in the byte buffer `rawData`. Bits are
/// numbered from the least significant bit of byte 0 upwards.
static bool getBit(const char *rawData, size_t bitPos) {
  const auto byte = static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
569 
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Only valid on big-endian hosts (asserted below). All but the last word of
/// `value` are copied verbatim. The last word is only partially meaningful
/// when `numBytes` is not a multiple of the word size, so it is routed through
/// an LE-format scratch buffer to extract just its used bytes.
///
/// \param value    the APInt to serialize; must have at least `numBytes`
///                 bytes of word storage (asserted).
/// \param numBytes number of bytes to write to `result`.
/// \param result   destination buffer; `numBytes` bytes are written.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // NOTE: despite its name, `numFilledWords` is a BYTE count: the number of
  // bytes in all words except the last one.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word in both buffers.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
  // Only the remaining (numBytes - lastWordPos) bytes are written, so `result`
  // is never written past `numBytes`.
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
599 
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// Only valid on big-endian hosts (asserted below). This is the inverse of
/// copyAPIntToArrayForBEmachine. The leading bytes fill all but the last word
/// of `result` verbatim. The trailing (possibly partial) word is routed
/// through an LE-format scratch buffer so that it lands in the right bytes of
/// the last APInt word.
///
/// \param inArray  source buffer; `numBytes` bytes are read.
/// \param numBytes number of bytes to read from `inArray`.
/// \param result   destination APInt; must have at least `numBytes` bytes of
///                 word storage (asserted).
///
/// NOTE(review): this writes into `result` through a const_cast of
/// `getRawData()`, bypassing APInt's public interface.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // NOTE: despite its name, `numFilledWords` is a BYTE count: the number of
  // bytes in all words of `result` except the last one.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
  // Only the remaining (numBytes - lastWordPos) bytes are read from `inArray`.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
635 
636 /// Writes value to the bit position `bitPos` in array `rawData`.
637 static void writeBits(char *rawData, size_t bitPos, APInt value) {
638   size_t bitWidth = value.getBitWidth();
639 
640   // If the bitwidth is 1 we just toggle the specific bit.
641   if (bitWidth == 1)
642     return setBit(rawData, bitPos, value.isOneValue());
643 
644   // Otherwise, the bit position is guaranteed to be byte aligned.
645   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
646   if (llvm::support::endian::system_endianness() ==
647       llvm::support::endianness::big) {
648     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
649     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
650     // work correctly in BE format.
651     // ex. `value` (2 words including 10 bytes)
652     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
653     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
654                                  rawData + (bitPos / CHAR_BIT));
655   } else {
656     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
657                 llvm::divideCeil(bitWidth, CHAR_BIT),
658                 rawData + (bitPos / CHAR_BIT));
659   }
660 }
661 
662 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
663 /// `rawData`.
664 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
665   // Handle a boolean bit position.
666   if (bitWidth == 1)
667     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
668 
669   // Otherwise, the bit position must be 8-bit aligned.
670   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
671   APInt result(bitWidth, 0);
672   if (llvm::support::endian::system_endianness() ==
673       llvm::support::endianness::big) {
674     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
675     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
676     // work correctly in BE format.
677     // ex. `result` (2 words including 10 bytes)
678     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
679     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
680                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
681   } else {
682     std::copy_n(rawData + (bitPos / CHAR_BIT),
683                 llvm::divideCeil(bitWidth, CHAR_BIT),
684                 const_cast<char *>(
685                     reinterpret_cast<const char *>(result.getRawData())));
686   }
687   return result;
688 }
689 
690 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
691 /// the same element count as 'type'.
692 template <typename Values>
693 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
694   return (values.size() == 1) ||
695          (type.getNumElements() == static_cast<int64_t>(values.size()));
696 }
697 
698 //===----------------------------------------------------------------------===//
699 // DenseElementsAttr Iterators
700 //===----------------------------------------------------------------------===//
701 
702 //===----------------------------------------------------------------------===//
703 // AttributeElementIterator
704 
705 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
706     DenseElementsAttr attr, size_t index)
707     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
708                                       Attribute, Attribute, Attribute>(
709           attr.getAsOpaquePointer(), index) {}
710 
711 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
712   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
713   Type eltTy = owner.getType().getElementType();
714   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
715     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
716   if (eltTy.isa<IndexType>())
717     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
718   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
719     IntElementIterator intIt(owner, index);
720     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
721     return FloatAttr::get(eltTy, *floatIt);
722   }
723   if (owner.isa<DenseStringElementsAttr>()) {
724     ArrayRef<StringRef> vals = owner.getRawStringData();
725     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
726   }
727   llvm_unreachable("unexpected element type");
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // BoolElementIterator
732 
733 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
734     DenseElementsAttr attr, size_t dataIndex)
735     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
736           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
737 
738 bool DenseElementsAttr::BoolElementIterator::operator*() const {
739   return getBit(getData(), getDataIndex());
740 }
741 
742 //===----------------------------------------------------------------------===//
743 // IntElementIterator
744 
745 DenseElementsAttr::IntElementIterator::IntElementIterator(
746     DenseElementsAttr attr, size_t dataIndex)
747     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
748           attr.getRawData().data(), attr.isSplat(), dataIndex),
749       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
750 
751 APInt DenseElementsAttr::IntElementIterator::operator*() const {
752   return readBits(getData(),
753                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
754                   bitWidth);
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // ComplexIntElementIterator
759 
760 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
761     DenseElementsAttr attr, size_t dataIndex)
762     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
763                                       std::complex<APInt>, std::complex<APInt>,
764                                       std::complex<APInt>>(
765           attr.getRawData().data(), attr.isSplat(), dataIndex) {
766   auto complexType = attr.getType().getElementType().cast<ComplexType>();
767   bitWidth = getDenseElementBitWidth(complexType.getElementType());
768 }
769 
770 std::complex<APInt>
771 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
772   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
773   size_t offset = getDataIndex() * storageWidth * 2;
774   return {readBits(getData(), offset, bitWidth),
775           readBits(getData(), offset + storageWidth, bitWidth)};
776 }
777 
778 //===----------------------------------------------------------------------===//
779 // FloatElementIterator
780 
781 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
782     const llvm::fltSemantics &smt, IntElementIterator it)
783     : llvm::mapped_iterator<IntElementIterator,
784                             std::function<APFloat(const APInt &)>>(
785           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
786 
787 //===----------------------------------------------------------------------===//
788 // ComplexFloatElementIterator
789 
790 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
791     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
792     : llvm::mapped_iterator<
793           ComplexIntElementIterator,
794           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
795           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
796             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
797           }) {}
798 
799 //===----------------------------------------------------------------------===//
800 // DenseElementsAttr
801 //===----------------------------------------------------------------------===//
802 
803 /// Method for support type inquiry through isa, cast and dyn_cast.
804 bool DenseElementsAttr::classof(Attribute attr) {
805   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
806 }
807 
808 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
809                                          ArrayRef<Attribute> values) {
810   assert(hasSameElementsOrSplat(type, values));
811 
812   // If the element type is not based on int/float/index, assume it is a string
813   // type.
814   auto eltType = type.getElementType();
815   if (!type.getElementType().isIntOrIndexOrFloat()) {
816     SmallVector<StringRef, 8> stringValues;
817     stringValues.reserve(values.size());
818     for (Attribute attr : values) {
819       assert(attr.isa<StringAttr>() &&
820              "expected string value for non integer/index/float element");
821       stringValues.push_back(attr.cast<StringAttr>().getValue());
822     }
823     return get(type, stringValues);
824   }
825 
826   // Otherwise, get the raw storage width to use for the allocation.
827   size_t bitWidth = getDenseElementBitWidth(eltType);
828   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
829 
830   // Compress the attribute values into a character buffer.
831   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
832                             values.size());
833   APInt intVal;
834   for (unsigned i = 0, e = values.size(); i < e; ++i) {
835     assert(eltType == values[i].getType() &&
836            "expected attribute value to have element type");
837     if (eltType.isa<FloatType>())
838       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
839     else if (eltType.isa<IntegerType>())
840       intVal = values[i].cast<IntegerAttr>().getValue();
841     else
842       llvm_unreachable("unexpected element type");
843 
844     assert(intVal.getBitWidth() == bitWidth &&
845            "expected value to have same bitwidth as element type");
846     writeBits(data.data(), i * storageBitWidth, intVal);
847   }
848   return DenseIntOrFPElementsAttr::getRaw(type, data,
849                                           /*isSplat=*/(values.size() == 1));
850 }
851 
852 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
853                                          ArrayRef<bool> values) {
854   assert(hasSameElementsOrSplat(type, values));
855   assert(type.getElementType().isInteger(1));
856 
857   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
858   for (int i = 0, e = values.size(); i != e; ++i)
859     setBit(buff.data(), i, values[i]);
860   return DenseIntOrFPElementsAttr::getRaw(type, buff,
861                                           /*isSplat=*/(values.size() == 1));
862 }
863 
864 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
865                                          ArrayRef<StringRef> values) {
866   assert(!type.getElementType().isIntOrFloat());
867   return DenseStringElementsAttr::get(type, values);
868 }
869 
870 /// Constructs a dense integer elements attribute from an array of APInt
871 /// values. Each APInt value is expected to have the same bitwidth as the
872 /// element type of 'type'.
873 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
874                                          ArrayRef<APInt> values) {
875   assert(type.getElementType().isIntOrIndex());
876   assert(hasSameElementsOrSplat(type, values));
877   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
878   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
879                                           /*isSplat=*/(values.size() == 1));
880 }
/// Constructs a dense complex-integer elements attribute from an array of
/// std::complex<APInt> values. The element type of 'type' must be a complex
/// type whose component type is an integer type. A single value produces a
/// splat.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  ComplexType complex = type.getElementType().cast<ComplexType>();
  assert(complex.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));
  // Each complex element is stored as two consecutive components, so each
  // component gets half of the complex storage width.
  size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
  // NOTE(review): this views the complex array as 2 * size() APInts. That
  // assumes std::complex<APInt> is laid out as {real, imag} with no padding.
  // The standard only guarantees that layout for float, double and long
  // double; confirm it holds for APInt on supported toolchains.
  ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
                          values.size() * 2);
  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
                                          /*isSplat=*/(values.size() == 1));
}
892 
893 // Constructs a dense float elements attribute from an array of APFloat
894 // values. Each APFloat value is expected to have the same bitwidth as the
895 // element type of 'type'.
896 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
897                                          ArrayRef<APFloat> values) {
898   assert(type.getElementType().isa<FloatType>());
899   assert(hasSameElementsOrSplat(type, values));
900   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
901   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
902                                           /*isSplat=*/(values.size() == 1));
903 }
/// Constructs a dense complex-float elements attribute from an array of
/// std::complex<APFloat> values. The element type of 'type' must be a complex
/// type whose component type is a float type. A single value produces a splat.
DenseElementsAttr
DenseElementsAttr::get(ShapedType type,
                       ArrayRef<std::complex<APFloat>> values) {
  ComplexType complex = type.getElementType().cast<ComplexType>();
  assert(complex.getElementType().isa<FloatType>());
  assert(hasSameElementsOrSplat(type, values));
  // NOTE(review): this views the complex array as 2 * size() APFloats. That
  // assumes std::complex<APFloat> is laid out as {real, imag} with no padding.
  // The standard only guarantees that layout for float, double and long
  // double; confirm it holds for APFloat on supported toolchains.
  ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
                           values.size() * 2);
  // Each component gets half of the complex element's storage width.
  size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
                                          /*isSplat=*/(values.size() == 1));
}
916 
917 /// Construct a dense elements attribute from a raw buffer representing the
918 /// data for this attribute. Users should generally not use this methods as
919 /// the expected buffer format may not be a form the user expects.
920 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
921                                                       ArrayRef<char> rawBuffer,
922                                                       bool isSplatBuffer) {
923   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
924 }
925 
926 /// Returns true if the given buffer is a valid raw buffer for the given type.
927 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
928                                          ArrayRef<char> rawBuffer,
929                                          bool &detectedSplat) {
930   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
931   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
932 
933   // Storage width of 1 is special as it is packed by the bit.
934   if (storageWidth == 1) {
935     // Check for a splat, or a buffer equal to the number of elements.
936     if ((detectedSplat = rawBuffer.size() == 1))
937       return true;
938     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
939   }
940   // All other types are 8-bit aligned.
941   if ((detectedSplat = rawBufferWidth == storageWidth))
942     return true;
943   return rawBufferWidth == (storageWidth * type.getNumElements());
944 }
945 
946 /// Check the information for a C++ data type, check if this type is valid for
947 /// the current attribute. This method is used to verify specific type
948 /// invariants that the templatized 'getValues' method cannot.
949 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
950                               bool isSigned) {
951   // Make sure that the data element size is the same as the type element width.
952   if (getDenseElementBitWidth(type) !=
953       static_cast<size_t>(dataEltSize * CHAR_BIT))
954     return false;
955 
956   // Check that the element type is either float or integer or index.
957   if (!isInt)
958     return type.isa<FloatType>();
959   if (type.isIndex())
960     return true;
961 
962   auto intType = type.dyn_cast<IntegerType>();
963   if (!intType)
964     return false;
965 
966   // Make sure signedness semantics is consistent.
967   if (intType.isSignless())
968     return true;
969   return intType.isSigned() ? isSigned : !isSigned;
970 }
971 
972 /// Defaults down the subclass implementation.
973 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
974                                                    ArrayRef<char> data,
975                                                    int64_t dataEltSize,
976                                                    bool isInt, bool isSigned) {
977   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
978                                                  isSigned);
979 }
980 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
981                                                       ArrayRef<char> data,
982                                                       int64_t dataEltSize,
983                                                       bool isInt,
984                                                       bool isSigned) {
985   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
986                                                     isInt, isSigned);
987 }
988 
989 /// A method used to verify specific type invariants that the templatized 'get'
990 /// method cannot.
991 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
992                                           bool isSigned) const {
993   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
994                              isSigned);
995 }
996 
997 /// Check the information for a C++ data type, check if this type is valid for
998 /// the current attribute.
999 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1000                                        bool isSigned) const {
1001   return ::isValidIntOrFloat(
1002       getType().getElementType().cast<ComplexType>().getElementType(),
1003       dataEltSize / 2, isInt, isSigned);
1004 }
1005 
1006 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1007 /// values are the same.
1008 bool DenseElementsAttr::isSplat() const {
1009   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1010 }
1011 
1012 /// Return the held element values as a range of Attributes.
1013 auto DenseElementsAttr::getAttributeValues() const
1014     -> llvm::iterator_range<AttributeElementIterator> {
1015   return {attr_value_begin(), attr_value_end()};
1016 }
1017 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
1018   return AttributeElementIterator(*this, 0);
1019 }
1020 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
1021   return AttributeElementIterator(*this, getNumElements());
1022 }
1023 
1024 /// Return the held element values as a range of bool. The element type of
1025 /// this attribute must be of integer type of bitwidth 1.
1026 auto DenseElementsAttr::getBoolValues() const
1027     -> llvm::iterator_range<BoolElementIterator> {
1028   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
1029   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
1030   (void)eltType;
1031   return {BoolElementIterator(*this, 0),
1032           BoolElementIterator(*this, getNumElements())};
1033 }
1034 
1035 /// Return the held element values as a range of APInts. The element type of
1036 /// this attribute must be of integer type.
1037 auto DenseElementsAttr::getIntValues() const
1038     -> llvm::iterator_range<IntElementIterator> {
1039   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1040   return {raw_int_begin(), raw_int_end()};
1041 }
1042 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
1043   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1044   return raw_int_begin();
1045 }
1046 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
1047   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1048   return raw_int_end();
1049 }
1050 auto DenseElementsAttr::getComplexIntValues() const
1051     -> llvm::iterator_range<ComplexIntElementIterator> {
1052   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1053   (void)eltTy;
1054   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
1055   return {ComplexIntElementIterator(*this, 0),
1056           ComplexIntElementIterator(*this, getNumElements())};
1057 }
1058 
1059 /// Return the held element values as a range of APFloat. The element type of
1060 /// this attribute must be of float type.
1061 auto DenseElementsAttr::getFloatValues() const
1062     -> llvm::iterator_range<FloatElementIterator> {
1063   auto elementType = getType().getElementType().cast<FloatType>();
1064   const auto &elementSemantics = elementType.getFloatSemantics();
1065   return {FloatElementIterator(elementSemantics, raw_int_begin()),
1066           FloatElementIterator(elementSemantics, raw_int_end())};
1067 }
1068 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1069   return getFloatValues().begin();
1070 }
1071 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1072   return getFloatValues().end();
1073 }
1074 auto DenseElementsAttr::getComplexFloatValues() const
1075     -> llvm::iterator_range<ComplexFloatElementIterator> {
1076   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1077   assert(eltTy.isa<FloatType>() && "expected complex float type");
1078   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1079   return {{semantics, {*this, 0}},
1080           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1081 }
1082 
1083 /// Return the raw storage data held by this attribute.
1084 ArrayRef<char> DenseElementsAttr::getRawData() const {
1085   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1086 }
1087 
1088 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1089   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1090 }
1091 
1092 /// Return a new DenseElementsAttr that has the same data as the current
1093 /// attribute, but has been reshaped to 'newType'. The new type must have the
1094 /// same total number of elements as well as element type.
1095 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1096   ShapedType curType = getType();
1097   if (curType == newType)
1098     return *this;
1099 
1100   (void)curType;
1101   assert(newType.getElementType() == curType.getElementType() &&
1102          "expected the same element type");
1103   assert(newType.getNumElements() == curType.getNumElements() &&
1104          "expected the same number of elements");
1105   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1106 }
1107 
1108 DenseElementsAttr
1109 DenseElementsAttr::mapValues(Type newElementType,
1110                              function_ref<APInt(const APInt &)> mapping) const {
1111   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1112 }
1113 
1114 DenseElementsAttr DenseElementsAttr::mapValues(
1115     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1116   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1117 }
1118 
1119 //===----------------------------------------------------------------------===//
1120 // DenseStringElementsAttr
1121 //===----------------------------------------------------------------------===//
1122 
1123 DenseStringElementsAttr
1124 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1125   return Base::get(type.getContext(), type, values, (values.size() == 1));
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // DenseIntOrFPElementsAttr
1130 //===----------------------------------------------------------------------===//
1131 
1132 /// Utility method to write a range of APInt values to a buffer.
1133 template <typename APRangeT>
1134 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1135                                 APRangeT &&values) {
1136   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1137   size_t offset = 0;
1138   for (auto it = values.begin(), e = values.end(); it != e;
1139        ++it, offset += storageWidth) {
1140     assert((*it).getBitWidth() <= storageWidth);
1141     writeBits(data.data(), offset, *it);
1142   }
1143 }
1144 
1145 /// Constructs a dense elements attribute from an array of raw APFloat values.
1146 /// Each APFloat value is expected to have the same bitwidth as the element
1147 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1148 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1149                                                    size_t storageWidth,
1150                                                    ArrayRef<APFloat> values,
1151                                                    bool isSplat) {
1152   std::vector<char> data;
1153   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1154   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1155   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1156 }
1157 
1158 /// Constructs a dense elements attribute from an array of raw APInt values.
1159 /// Each APInt value is expected to have the same bitwidth as the element type
1160 /// of 'type'.
1161 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1162                                                    size_t storageWidth,
1163                                                    ArrayRef<APInt> values,
1164                                                    bool isSplat) {
1165   std::vector<char> data;
1166   writeAPIntsToBuffer(storageWidth, data, values);
1167   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1168 }
1169 
1170 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1171                                                    ArrayRef<char> data,
1172                                                    bool isSplat) {
1173   assert((type.isa<RankedTensorType, VectorType>()) &&
1174          "type must be ranked tensor or vector");
1175   assert(type.hasStaticShape() && "type must have static shape");
1176   return Base::get(type.getContext(), type, data, isSplat);
1177 }
1178 
1179 /// Overload of the raw 'get' method that asserts that the given type is of
1180 /// complex type. This method is used to verify type invariants that the
1181 /// templatized 'get' method cannot.
1182 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1183                                                           ArrayRef<char> data,
1184                                                           int64_t dataEltSize,
1185                                                           bool isInt,
1186                                                           bool isSigned) {
1187   assert(::isValidIntOrFloat(
1188       type.getElementType().cast<ComplexType>().getElementType(),
1189       dataEltSize / 2, isInt, isSigned));
1190 
1191   int64_t numElements = data.size() / dataEltSize;
1192   assert(numElements == 1 || numElements == type.getNumElements());
1193   return getRaw(type, data, /*isSplat=*/numElements == 1);
1194 }
1195 
1196 /// Overload of the 'getRaw' method that asserts that the given type is of
1197 /// integer type. This method is used to verify type invariants that the
1198 /// templatized 'get' method cannot.
1199 DenseElementsAttr
1200 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1201                                            int64_t dataEltSize, bool isInt,
1202                                            bool isSigned) {
1203   assert(
1204       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1205 
1206   int64_t numElements = data.size() / dataEltSize;
1207   assert(numElements == 1 || numElements == type.getNumElements());
1208   return getRaw(type, data, /*isSplat=*/numElements == 1);
1209 }
1210 
1211 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1212     const char *inRawData, char *outRawData, size_t elementBitWidth,
1213     size_t numElements) {
1214   using llvm::support::ulittle16_t;
1215   using llvm::support::ulittle32_t;
1216   using llvm::support::ulittle64_t;
1217 
1218   assert(llvm::support::endian::system_endianness() == // NOLINT
1219          llvm::support::endianness::big);              // NOLINT
1220   // NOLINT to avoid warning message about replacing by static_assert()
1221 
1222   // Following std::copy_n always converts endianness on BE machine.
1223   switch (elementBitWidth) {
1224   case 16: {
1225     const ulittle16_t *inRawDataPos =
1226         reinterpret_cast<const ulittle16_t *>(inRawData);
1227     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1228     std::copy_n(inRawDataPos, numElements, outDataPos);
1229     break;
1230   }
1231   case 32: {
1232     const ulittle32_t *inRawDataPos =
1233         reinterpret_cast<const ulittle32_t *>(inRawData);
1234     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1235     std::copy_n(inRawDataPos, numElements, outDataPos);
1236     break;
1237   }
1238   case 64: {
1239     const ulittle64_t *inRawDataPos =
1240         reinterpret_cast<const ulittle64_t *>(inRawData);
1241     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1242     std::copy_n(inRawDataPos, numElements, outDataPos);
1243     break;
1244   }
1245   default: {
1246     size_t nBytes = elementBitWidth / CHAR_BIT;
1247     for (size_t i = 0; i < nBytes; i++)
1248       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1249     break;
1250   }
1251   }
1252 }
1253 
1254 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1255     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1256     ShapedType type) {
1257   size_t numElements = type.getNumElements();
1258   Type elementType = type.getElementType();
1259   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1260     elementType = complexTy.getElementType();
1261     numElements = numElements * 2;
1262   }
1263   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1264   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1265          inRawData.size() <= outRawData.size());
1266   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1267                                   elementBitWidth, numElements);
1268 }
1269 
1270 //===----------------------------------------------------------------------===//
1271 // DenseFPElementsAttr
1272 //===----------------------------------------------------------------------===//
1273 
1274 template <typename Fn, typename Attr>
1275 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1276                                 Type newElementType,
1277                                 llvm::SmallVectorImpl<char> &data) {
1278   size_t bitWidth = getDenseElementBitWidth(newElementType);
1279   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1280 
1281   ShapedType newArrayType;
1282   if (inType.isa<RankedTensorType>())
1283     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1284   else if (inType.isa<UnrankedTensorType>())
1285     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1286   else if (inType.isa<VectorType>())
1287     newArrayType = VectorType::get(inType.getShape(), newElementType);
1288   else
1289     assert(newArrayType && "Unhandled tensor type");
1290 
1291   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1292   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1293 
1294   // Functor used to process a single element value of the attribute.
1295   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1296     auto newInt = mapping(value);
1297     assert(newInt.getBitWidth() == bitWidth);
1298     writeBits(data.data(), index * storageBitWidth, newInt);
1299   };
1300 
1301   // Check for the splat case.
1302   if (attr.isSplat()) {
1303     processElt(*attr.begin(), /*index=*/0);
1304     return newArrayType;
1305   }
1306 
1307   // Otherwise, process all of the element values.
1308   uint64_t elementIdx = 0;
1309   for (auto value : attr)
1310     processElt(value, elementIdx++);
1311   return newArrayType;
1312 }
1313 
1314 DenseElementsAttr DenseFPElementsAttr::mapValues(
1315     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1316   llvm::SmallVector<char, 8> elementData;
1317   auto newArrayType =
1318       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1319 
1320   return getRaw(newArrayType, elementData, isSplat());
1321 }
1322 
1323 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1324 bool DenseFPElementsAttr::classof(Attribute attr) {
1325   return attr.isa<DenseElementsAttr>() &&
1326          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // DenseIntElementsAttr
1331 //===----------------------------------------------------------------------===//
1332 
1333 DenseElementsAttr DenseIntElementsAttr::mapValues(
1334     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1335   llvm::SmallVector<char, 8> elementData;
1336   auto newArrayType =
1337       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1338 
1339   return getRaw(newArrayType, elementData, isSplat());
1340 }
1341 
1342 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1343 bool DenseIntElementsAttr::classof(Attribute attr) {
1344   return attr.isa<DenseElementsAttr>() &&
1345          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1346 }
1347 
1348 //===----------------------------------------------------------------------===//
1349 // OpaqueElementsAttr
1350 //===----------------------------------------------------------------------===//
1351 
1352 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1353                                            StringRef bytes) {
1354   assert(TensorType::isValidElementType(type.getElementType()) &&
1355          "Input element type should be a valid tensor element type");
1356   return Base::get(type.getContext(), type, dialect, bytes);
1357 }
1358 
1359 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1360 
1361 /// Return the value at the given index. If index does not refer to a valid
1362 /// element, then a null attribute is returned.
1363 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1364   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1365   return Attribute();
1366 }
1367 
1368 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1369 
1370 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1371   auto *d = getDialect();
1372   if (!d)
1373     return true;
1374   auto *interface =
1375       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1376   if (!interface)
1377     return true;
1378   return failed(interface->decode(*this, result));
1379 }
1380 
1381 //===----------------------------------------------------------------------===//
1382 // SparseElementsAttr
1383 //===----------------------------------------------------------------------===//
1384 
1385 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1386                                            DenseElementsAttr indices,
1387                                            DenseElementsAttr values) {
1388   assert(indices.getType().getElementType().isInteger(64) &&
1389          "expected sparse indices to be 64-bit integer values");
1390   assert((type.isa<RankedTensorType, VectorType>()) &&
1391          "type must be ranked tensor or vector");
1392   assert(type.hasStaticShape() && "type must have static shape");
1393   return Base::get(type.getContext(), type,
1394                    indices.cast<DenseIntElementsAttr>(), values);
1395 }
1396 
1397 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1398   return getImpl()->indices;
1399 }
1400 
1401 DenseElementsAttr SparseElementsAttr::getValues() const {
1402   return getImpl()->values;
1403 }
1404 
/// Return the value of the element at the given index.
///
/// Elements whose index is not stored in the sparse index list are implicit
/// zeros, produced by getZeroAttr().
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // The only stored coordinate has every component equal to the splat value;
    // any requested index with a differing component is an implicit zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Scan the stored indices (row-major, `rank` values per index) for the
  // requested one. Only a single lookup is needed, so a direct scan avoids
  // hashing every stored index into a temporary map (and its allocation).
  // If an index is stored more than once, its first occurrence is used.
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
    llvm::ArrayRef<uint64_t> storedIndex(
        &*std::next(sparseIndexValues.begin(), i * rank), rank);
    if (storedIndex.equals(index))
      return getValues().getValue(i);
  }

  // The provided index is not stored, so the element is zero.
  return getZeroAttr();
}
1445 
1446 /// Get a zero APFloat for the given sparse attribute.
1447 APFloat SparseElementsAttr::getZeroAPFloat() const {
1448   auto eltType = getType().getElementType().cast<FloatType>();
1449   return APFloat(eltType.getFloatSemantics());
1450 }
1451 
1452 /// Get a zero APInt for the given sparse attribute.
1453 APInt SparseElementsAttr::getZeroAPInt() const {
1454   auto eltType = getType().getElementType().cast<IntegerType>();
1455   return APInt::getNullValue(eltType.getWidth());
1456 }
1457 
1458 /// Get a zero attribute for the given attribute type.
1459 Attribute SparseElementsAttr::getZeroAttr() const {
1460   auto eltType = getType().getElementType();
1461 
1462   // Handle floating point elements.
1463   if (eltType.isa<FloatType>())
1464     return FloatAttr::get(eltType, 0);
1465 
1466   // Otherwise, this is an integer.
1467   // TODO: Handle StringAttr here.
1468   return IntegerAttr::get(eltType, 0);
1469 }
1470 
/// Compute the row-major flattened position of every stored sparse index.
///
/// A splat index tensor is collapsed to a single coordinate whose components
/// all equal the splat value, so exactly one position is returned for it.
std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
  std::vector<ptrdiff_t> positions;

  auto sparseIndices = getIndices();
  // Indices are 64-bit integers, so each consecutive run of `rank` values in
  // the element range can be viewed directly as one multi-dimensional index.
  auto indexValues = sparseIndices.getValues<uint64_t>();

  if (!sparseIndices.isSplat()) {
    size_t rank = getType().getRank();
    auto numIndices = sparseIndices.getType().getDimSize(0);
    for (size_t row = 0, e = numIndices; row != e; ++row) {
      ArrayRef<uint64_t> coord(&*std::next(indexValues.begin(), row * rank),
                               rank);
      positions.push_back(getFlattenedIndex(coord));
    }
    return positions;
  }

  SmallVector<uint64_t, 8> splatCoord(getType().getRank(),
                                      *indexValues.begin());
  positions.push_back(getFlattenedIndex(splatCoord));
  return positions;
}
1495 
1496 //===----------------------------------------------------------------------===//
1497 // MutableDictionaryAttr
1498 //===----------------------------------------------------------------------===//
1499 
1500 MutableDictionaryAttr::MutableDictionaryAttr(
1501     ArrayRef<NamedAttribute> attributes) {
1502   setAttrs(attributes);
1503 }
1504 
1505 /// Return the underlying dictionary attribute.
1506 DictionaryAttr
1507 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1508   // Construct empty DictionaryAttr if needed.
1509   if (!attrs)
1510     return DictionaryAttr::get({}, context);
1511   return attrs;
1512 }
1513 
1514 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1515   return attrs ? attrs.getValue() : llvm::None;
1516 }
1517 
1518 /// Replace the held attributes with ones provided in 'newAttrs'.
1519 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1520   // Don't create an attribute list if there are no attributes.
1521   if (attributes.empty())
1522     attrs = nullptr;
1523   else
1524     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1525 }
1526 
1527 /// Return the specified attribute if present, null otherwise.
1528 Attribute MutableDictionaryAttr::get(StringRef name) const {
1529   return attrs ? attrs.get(name) : nullptr;
1530 }
1531 
1532 /// Return the specified attribute if present, null otherwise.
1533 Attribute MutableDictionaryAttr::get(Identifier name) const {
1534   return attrs ? attrs.get(name) : nullptr;
1535 }
1536 
1537 /// Return the specified named attribute if present, None otherwise.
1538 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1539   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1540 }
1541 Optional<NamedAttribute>
1542 MutableDictionaryAttr::getNamed(Identifier name) const {
1543   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1544 }
1545 
/// Set attribute `name` to `value`. If an attribute with that name already
/// exists, its value is replaced (and nothing changes when the value is
/// identical). Otherwise the new entry is inserted at its sorted position.
/// `value` must not be null.
void MutableDictionaryAttr::set(Identifier name, Attribute value) {
  assert(value && "attributes may never be null");

  ArrayRef<NamedAttribute> current = getAttrs();
  const NamedAttribute *pos = llvm::find_if(
      current, [name](NamedAttribute attr) { return attr.first == name; });

  SmallVector<NamedAttribute, 8> updated;
  if (pos == current.end()) {
    // Not present: splice the new entry in at its sorted position.
    pos = llvm::lower_bound(current, name);
    updated.reserve(current.size() + 1);
    updated.append(current.begin(), pos);
    updated.push_back({name, value});
    updated.append(pos, current.end());
  } else {
    // Present with the same value: the dictionary is already up to date.
    if (pos->second == value)
      return;
    // Present with a different value: copy the list and overwrite that entry.
    updated.assign(current.begin(), current.end());
    updated[pos - current.begin()].second = value;
  }

  // The list is still sorted in both branches, so the sorted builder applies.
  attrs = DictionaryAttr::getWithSorted(updated, value.getContext());
}
1575 
1576 /// Remove the attribute with the specified name if it exists.  The return
1577 /// value indicates whether the attribute was present or not.
1578 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1579   auto origAttrs = getAttrs();
1580   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1581     if (origAttrs[i].first == name) {
1582       // Handle the simple case of removing the only attribute in the list.
1583       if (e == 1) {
1584         attrs = nullptr;
1585         return RemoveResult::Removed;
1586       }
1587 
1588       SmallVector<NamedAttribute, 8> newAttrs;
1589       newAttrs.reserve(origAttrs.size() - 1);
1590       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1591       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1592       attrs = DictionaryAttr::getWithSorted(newAttrs,
1593                                             newAttrs[0].second.getContext());
1594       return RemoveResult::Removed;
1595     }
1596   }
1597   return RemoveResult::NotFound;
1598 }
1599 
1600 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
1601   return strcmp(lhs.first.data(), rhs.first.data()) < 0;
1602 }
1603 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
1604   // This is correct even when attr.first.data()[name.size()] is not a zero
1605   // string terminator, because we only care about a less than comparison.
1606   // This can't use memcmp, because it doesn't guarantee that it will stop
1607   // reading both buffers if one is shorter than the other, even if there is
1608   // a difference.
1609   return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
1610 }
1611