1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/Endian.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 //===----------------------------------------------------------------------===//
// Tablegen Attribute Definitions
29 //===----------------------------------------------------------------------===//
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/IR/BuiltinAttributes.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // BuiltinDialect
36 //===----------------------------------------------------------------------===//
37 
registerAttributes()38 void BuiltinDialect::registerAttributes() {
39   addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
40                 DenseIntOrFPElementsAttr, DenseStringElementsAttr,
41                 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
42                 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
43                 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // ArrayAttr
48 //===----------------------------------------------------------------------===//
49 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const50 void ArrayAttr::walkImmediateSubElements(
51     function_ref<void(Attribute)> walkAttrsFn,
52     function_ref<void(Type)> walkTypesFn) const {
53   for (Attribute attr : getValue())
54     walkAttrsFn(attr);
55 }
56 
57 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const58 ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
59                                        ArrayRef<Type> replTypes) const {
60   return get(getContext(), replAttrs);
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // DictionaryAttr
65 //===----------------------------------------------------------------------===//
66 
67 /// Helper function that does either an in place sort or sorts from source array
68 /// into destination. If inPlace then storage is both the source and the
69 /// destination, else value is the source and storage destination. Returns
70 /// whether source was sorted.
71 template <bool inPlace>
dictionaryAttrSort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)72 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
73                                SmallVectorImpl<NamedAttribute> &storage) {
74   // Specialize for the common case.
75   switch (value.size()) {
76   case 0:
77     // Zero already sorted.
78     if (!inPlace)
79       storage.clear();
80     break;
81   case 1:
82     // One already sorted but may need to be copied.
83     if (!inPlace)
84       storage.assign({value[0]});
85     break;
86   case 2: {
87     bool isSorted = value[0] < value[1];
88     if (inPlace) {
89       if (!isSorted)
90         std::swap(storage[0], storage[1]);
91     } else if (isSorted) {
92       storage.assign({value[0], value[1]});
93     } else {
94       storage.assign({value[1], value[0]});
95     }
96     return !isSorted;
97   }
98   default:
99     if (!inPlace)
100       storage.assign(value.begin(), value.end());
101     // Check to see they are sorted already.
102     bool isSorted = llvm::is_sorted(value);
103     // If not, do a general sort.
104     if (!isSorted)
105       llvm::array_pod_sort(storage.begin(), storage.end());
106     return !isSorted;
107   }
108   return false;
109 }
110 
111 /// Returns an entry with a duplicate name from the given sorted array of named
112 /// attributes. Returns llvm::None if all elements have unique names.
113 static Optional<NamedAttribute>
findDuplicateElement(ArrayRef<NamedAttribute> value)114 findDuplicateElement(ArrayRef<NamedAttribute> value) {
115   const Optional<NamedAttribute> none{llvm::None};
116   if (value.size() < 2)
117     return none;
118 
119   if (value.size() == 2)
120     return value[0].getName() == value[1].getName() ? value[0] : none;
121 
122   const auto *it = std::adjacent_find(value.begin(), value.end(),
123                                       [](NamedAttribute l, NamedAttribute r) {
124                                         return l.getName() == r.getName();
125                                       });
126   return it != value.end() ? *it : none;
127 }
128 
sort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)129 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
130                           SmallVectorImpl<NamedAttribute> &storage) {
131   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
132   assert(!findDuplicateElement(storage) &&
133          "DictionaryAttr element names must be unique");
134   return isSorted;
135 }
136 
sortInPlace(SmallVectorImpl<NamedAttribute> & array)137 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
138   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
139   assert(!findDuplicateElement(array) &&
140          "DictionaryAttr element names must be unique");
141   return isSorted;
142 }
143 
144 Optional<NamedAttribute>
findDuplicate(SmallVectorImpl<NamedAttribute> & array,bool isSorted)145 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
146                               bool isSorted) {
147   if (!isSorted)
148     dictionaryAttrSort</*inPlace=*/true>(array, array);
149   return findDuplicateElement(array);
150 }
151 
get(MLIRContext * context,ArrayRef<NamedAttribute> value)152 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
153                                    ArrayRef<NamedAttribute> value) {
154   if (value.empty())
155     return DictionaryAttr::getEmpty(context);
156 
157   // We need to sort the element list to canonicalize it.
158   SmallVector<NamedAttribute, 8> storage;
159   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
160     value = storage;
161   assert(!findDuplicateElement(value) &&
162          "DictionaryAttr element names must be unique");
163   return Base::get(context, value);
164 }
165 /// Construct a dictionary with an array of values that is known to already be
166 /// sorted by name and uniqued.
getWithSorted(MLIRContext * context,ArrayRef<NamedAttribute> value)167 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
168                                              ArrayRef<NamedAttribute> value) {
169   if (value.empty())
170     return DictionaryAttr::getEmpty(context);
171   // Ensure that the attribute elements are unique and sorted.
172   assert(llvm::is_sorted(
173              value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
174          "expected attribute values to be sorted");
175   assert(!findDuplicateElement(value) &&
176          "DictionaryAttr element names must be unique");
177   return Base::get(context, value);
178 }
179 
180 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const181 Attribute DictionaryAttr::get(StringRef name) const {
182   auto it = impl::findAttrSorted(begin(), end(), name);
183   return it.second ? it.first->getValue() : Attribute();
184 }
get(StringAttr name) const185 Attribute DictionaryAttr::get(StringAttr name) const {
186   auto it = impl::findAttrSorted(begin(), end(), name);
187   return it.second ? it.first->getValue() : Attribute();
188 }
189 
190 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const191 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
192   auto it = impl::findAttrSorted(begin(), end(), name);
193   return it.second ? *it.first : Optional<NamedAttribute>();
194 }
getNamed(StringAttr name) const195 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
196   auto it = impl::findAttrSorted(begin(), end(), name);
197   return it.second ? *it.first : Optional<NamedAttribute>();
198 }
199 
200 /// Return whether the specified attribute is present.
contains(StringRef name) const201 bool DictionaryAttr::contains(StringRef name) const {
202   return impl::findAttrSorted(begin(), end(), name).second;
203 }
contains(StringAttr name) const204 bool DictionaryAttr::contains(StringAttr name) const {
205   return impl::findAttrSorted(begin(), end(), name).second;
206 }
207 
begin() const208 DictionaryAttr::iterator DictionaryAttr::begin() const {
209   return getValue().begin();
210 }
end() const211 DictionaryAttr::iterator DictionaryAttr::end() const {
212   return getValue().end();
213 }
size() const214 size_t DictionaryAttr::size() const { return getValue().size(); }
215 
getEmptyUnchecked(MLIRContext * context)216 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
217   return Base::get(context, ArrayRef<NamedAttribute>());
218 }
219 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const220 void DictionaryAttr::walkImmediateSubElements(
221     function_ref<void(Attribute)> walkAttrsFn,
222     function_ref<void(Type)> walkTypesFn) const {
223   for (const NamedAttribute &attr : getValue())
224     walkAttrsFn(attr.getValue());
225 }
226 
227 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const228 DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
229                                             ArrayRef<Type> replTypes) const {
230   std::vector<NamedAttribute> vec = getValue().vec();
231   for (auto &it : llvm::enumerate(replAttrs))
232     vec[it.index()].setValue(it.value());
233 
234   // The above only modifies the mapped value, but not the key, and therefore
235   // not the order of the elements. It remains sorted
236   return getWithSorted(getContext(), vec);
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // StringAttr
241 //===----------------------------------------------------------------------===//
242 
getEmptyStringAttrUnchecked(MLIRContext * context)243 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
244   return Base::get(context, "", NoneType::get(context));
245 }
246 
247 /// Twine support for StringAttr.
get(MLIRContext * context,const Twine & twine)248 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
249   // Fast-path empty twine.
250   if (twine.isTriviallyEmpty())
251     return get(context);
252   SmallVector<char, 32> tempStr;
253   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
254 }
255 
256 /// Twine support for StringAttr.
get(const Twine & twine,Type type)257 StringAttr StringAttr::get(const Twine &twine, Type type) {
258   SmallVector<char, 32> tempStr;
259   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
260 }
261 
getValue() const262 StringRef StringAttr::getValue() const { return getImpl()->value; }
263 
getReferencedDialect() const264 Dialect *StringAttr::getReferencedDialect() const {
265   return getImpl()->referencedDialect;
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // FloatAttr
270 //===----------------------------------------------------------------------===//
271 
getValueAsDouble() const272 double FloatAttr::getValueAsDouble() const {
273   return getValueAsDouble(getValue());
274 }
getValueAsDouble(APFloat value)275 double FloatAttr::getValueAsDouble(APFloat value) {
276   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
277     bool losesInfo = false;
278     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
279                   &losesInfo);
280   }
281   return value.convertToDouble();
282 }
283 
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,APFloat value)284 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
285                                 Type type, APFloat value) {
286   // Verify that the type is correct.
287   if (!type.isa<FloatType>())
288     return emitError() << "expected floating point type";
289 
290   // Verify that the type semantics match that of the value.
291   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
292     return emitError()
293            << "FloatAttr type doesn't match the type implied by its value";
294   }
295   return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // SymbolRefAttr
300 //===----------------------------------------------------------------------===//
301 
get(MLIRContext * ctx,StringRef value,ArrayRef<FlatSymbolRefAttr> nestedRefs)302 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
303                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
304   return get(StringAttr::get(ctx, value), nestedRefs);
305 }
306 
get(MLIRContext * ctx,StringRef value)307 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
308   return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
309 }
310 
get(StringAttr value)311 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
312   return get(value, {}).cast<FlatSymbolRefAttr>();
313 }
314 
get(Operation * symbol)315 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
316   auto symName =
317       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
318   assert(symName && "value does not have a valid symbol name");
319   return SymbolRefAttr::get(symName);
320 }
321 
getLeafReference() const322 StringAttr SymbolRefAttr::getLeafReference() const {
323   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
324   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
325 }
326 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const327 void SymbolRefAttr::walkImmediateSubElements(
328     function_ref<void(Attribute)> walkAttrsFn,
329     function_ref<void(Type)> walkTypesFn) const {
330   walkAttrsFn(getRootReference());
331   for (FlatSymbolRefAttr ref : getNestedReferences())
332     walkAttrsFn(ref);
333 }
334 
335 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const336 SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
337                                            ArrayRef<Type> replTypes) const {
338   ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
339   ArrayRef<FlatSymbolRefAttr> nestedRefs(
340       static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
341       rawNestedRefs.size());
342   return get(replAttrs[0].cast<StringAttr>(), nestedRefs);
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // IntegerAttr
347 //===----------------------------------------------------------------------===//
348 
getInt() const349 int64_t IntegerAttr::getInt() const {
350   assert((getType().isIndex() || getType().isSignlessInteger()) &&
351          "must be signless integer");
352   return getValue().getSExtValue();
353 }
354 
getSInt() const355 int64_t IntegerAttr::getSInt() const {
356   assert(getType().isSignedInteger() && "must be signed integer");
357   return getValue().getSExtValue();
358 }
359 
getUInt() const360 uint64_t IntegerAttr::getUInt() const {
361   assert(getType().isUnsignedInteger() && "must be unsigned integer");
362   return getValue().getZExtValue();
363 }
364 
365 /// Return the value as an APSInt which carries the signed from the type of
366 /// the attribute.  This traps on signless integers types!
getAPSInt() const367 APSInt IntegerAttr::getAPSInt() const {
368   assert(!getType().isSignlessInteger() &&
369          "Signless integers don't carry a sign for APSInt");
370   return APSInt(getValue(), getType().isUnsignedInteger());
371 }
372 
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,APInt value)373 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
374                                   Type type, APInt value) {
375   if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
376     if (integerType.getWidth() != value.getBitWidth())
377       return emitError() << "integer type bit width (" << integerType.getWidth()
378                          << ") doesn't match value bit width ("
379                          << value.getBitWidth() << ")";
380     return success();
381   }
382   if (type.isa<IndexType>())
383     return success();
384   return emitError() << "expected integer or index type";
385 }
386 
getBoolAttrUnchecked(IntegerType type,bool value)387 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
388   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
389   return attr.cast<BoolAttr>();
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // BoolAttr
394 //===----------------------------------------------------------------------===//
395 
getValue() const396 bool BoolAttr::getValue() const {
397   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
398   return storage->value.getBoolValue();
399 }
400 
classof(Attribute attr)401 bool BoolAttr::classof(Attribute attr) {
402   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
403   return intAttr && intAttr.getType().isSignlessInteger(1);
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // OpaqueAttr
408 //===----------------------------------------------------------------------===//
409 
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef attrData,Type type)410 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
411                                  StringAttr dialect, StringRef attrData,
412                                  Type type) {
413   if (!Dialect::isValidNamespace(dialect.strref()))
414     return emitError() << "invalid dialect namespace '" << dialect << "'";
415 
416   // Check that the dialect is actually registered.
417   MLIRContext *context = dialect.getContext();
418   if (!context->allowsUnregisteredDialects() &&
419       !context->getLoadedDialect(dialect.strref())) {
420     return emitError()
421            << "#" << dialect << "<\"" << attrData << "\"> : " << type
422            << " attribute created with unregistered dialect. If this is "
423               "intended, please call allowUnregisteredDialects() on the "
424               "MLIRContext, or use -allow-unregistered-dialect with "
425               "the MLIR opt tool used";
426   }
427 
428   return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // DenseElementsAttr Utilities
433 //===----------------------------------------------------------------------===//
434 
/// Return the number of bits one element of width `origWidth` occupies in a
/// dense element buffer. 1-bit elements keep their width; every other width
/// is rounded up to a multiple of 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return origWidth;
  const size_t kBitsPerByte = 8;
  return (origWidth + kBitsPerByte - 1) / kBitsPerByte * kBitsPerByte;
}
getDenseElementStorageWidth(Type elementType)440 static size_t getDenseElementStorageWidth(Type elementType) {
441   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
442 }
443 
/// Set (if `value`) or clear the bit at absolute bit offset `bitPos` of
/// `rawData`. Bit 0 is the least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}
451 
/// Return the bit at absolute bit offset `bitPos` of `rawData`. Bit 0 is the
/// least significant bit of byte 0.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  const unsigned bitInByte = static_cast<unsigned>(bitPos % CHAR_BIT);
  return ((byte >> bitInByte) & 1u) != 0;
}
456 
457 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
458 /// BE format.
copyAPIntToArrayForBEmachine(APInt value,size_t numBytes,char * result)459 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
460                                          char *result) {
461   assert(llvm::support::endian::system_endianness() == // NOLINT
462          llvm::support::endianness::big);              // NOLINT
463   assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
464 
465   // Copy the words filled with data.
466   // For example, when `value` has 2 words, the first word is filled with data.
467   // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
468   size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
469   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
470               numFilledWords, result);
471   // Convert last word of APInt to LE format and store it in char
472   // array(`valueLE`).
473   // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
474   size_t lastWordPos = numFilledWords;
475   SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
476   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
477       reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
478       valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
479   // Extract actual APInt data from `valueLE`, convert endianness to BE format,
480   // and store it in `result`.
481   // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
482   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
483       valueLE.begin(), result + lastWordPos,
484       (numBytes - lastWordPos) * CHAR_BIT, 1);
485 }
486 
487 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
488 /// format.
copyArrayToAPIntForBEmachine(const char * inArray,size_t numBytes,APInt & result)489 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
490                                          APInt &result) {
491   assert(llvm::support::endian::system_endianness() == // NOLINT
492          llvm::support::endianness::big);              // NOLINT
493   assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
494 
495   // Copy the data that fills the word of `result` from `inArray`.
496   // For example, when `result` has 2 words, the first word will be filled with
497   // data. So, the first 8 bytes are copied from `inArray` here.
498   // `inArray` (10 bytes, BE): |abcdefgh|ij|
499   //                     ==> `result` (2 words, BE): |abcdefgh|--------|
500   size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
501   std::copy_n(
502       inArray, numFilledWords,
503       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
504 
505   // Convert array data which will be last word of `result` to LE format, and
506   // store it in char array(`inArrayLE`).
507   // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
508   size_t lastWordPos = numFilledWords;
509   SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
510   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
511       inArray + lastWordPos, inArrayLE.begin(),
512       (numBytes - lastWordPos) * CHAR_BIT, 1);
513 
514   // Convert `inArrayLE` to BE format, and store it in last word of `result`.
515   // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
516   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
517       inArrayLE.begin(),
518       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
519           lastWordPos,
520       APInt::APINT_BITS_PER_WORD, 1);
521 }
522 
523 /// Writes value to the bit position `bitPos` in array `rawData`.
writeBits(char * rawData,size_t bitPos,APInt value)524 static void writeBits(char *rawData, size_t bitPos, APInt value) {
525   size_t bitWidth = value.getBitWidth();
526 
527   // If the bitwidth is 1 we just toggle the specific bit.
528   if (bitWidth == 1)
529     return setBit(rawData, bitPos, value.isOneValue());
530 
531   // Otherwise, the bit position is guaranteed to be byte aligned.
532   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
533   if (llvm::support::endian::system_endianness() ==
534       llvm::support::endianness::big) {
535     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
536     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
537     // work correctly in BE format.
538     // ex. `value` (2 words including 10 bytes)
539     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
540     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
541                                  rawData + (bitPos / CHAR_BIT));
542   } else {
543     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
544                 llvm::divideCeil(bitWidth, CHAR_BIT),
545                 rawData + (bitPos / CHAR_BIT));
546   }
547 }
548 
549 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
550 /// `rawData`.
readBits(const char * rawData,size_t bitPos,size_t bitWidth)551 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
552   // Handle a boolean bit position.
553   if (bitWidth == 1)
554     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
555 
556   // Otherwise, the bit position must be 8-bit aligned.
557   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
558   APInt result(bitWidth, 0);
559   if (llvm::support::endian::system_endianness() ==
560       llvm::support::endianness::big) {
561     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
562     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
563     // work correctly in BE format.
564     // ex. `result` (2 words including 10 bytes)
565     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
566     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
567                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
568   } else {
569     std::copy_n(rawData + (bitPos / CHAR_BIT),
570                 llvm::divideCeil(bitWidth, CHAR_BIT),
571                 const_cast<char *>(
572                     reinterpret_cast<const char *>(result.getRawData())));
573   }
574   return result;
575 }
576 
577 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
578 /// the same element count as 'type'.
579 template <typename Values>
hasSameElementsOrSplat(ShapedType type,const Values & values)580 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
581   return (values.size() == 1) ||
582          (type.getNumElements() == static_cast<int64_t>(values.size()));
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // DenseElementsAttr Iterators
587 //===----------------------------------------------------------------------===//
588 
589 //===----------------------------------------------------------------------===//
590 // AttributeElementIterator
591 
AttributeElementIterator(DenseElementsAttr attr,size_t index)592 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
593     DenseElementsAttr attr, size_t index)
594     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
595                                       Attribute, Attribute, Attribute>(
596           attr.getAsOpaquePointer(), index) {}
597 
operator *() const598 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
599   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
600   Type eltTy = owner.getElementType();
601   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
602     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
603   if (eltTy.isa<IndexType>())
604     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
605   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
606     IntElementIterator intIt(owner, index);
607     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
608     return FloatAttr::get(eltTy, *floatIt);
609   }
610   if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
611     auto complexEltTy = complexTy.getElementType();
612     ComplexIntElementIterator complexIntIt(owner, index);
613     if (complexEltTy.isa<IntegerType>()) {
614       auto value = *complexIntIt;
615       auto real = IntegerAttr::get(complexEltTy, value.real());
616       auto imag = IntegerAttr::get(complexEltTy, value.imag());
617       return ArrayAttr::get(complexTy.getContext(),
618                             ArrayRef<Attribute>{real, imag});
619     }
620 
621     ComplexFloatElementIterator complexFloatIt(
622         complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
623     auto value = *complexFloatIt;
624     auto real = FloatAttr::get(complexEltTy, value.real());
625     auto imag = FloatAttr::get(complexEltTy, value.imag());
626     return ArrayAttr::get(complexTy.getContext(),
627                           ArrayRef<Attribute>{real, imag});
628   }
629   if (owner.isa<DenseStringElementsAttr>()) {
630     ArrayRef<StringRef> vals = owner.getRawStringData();
631     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
632   }
633   llvm_unreachable("unexpected element type");
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // BoolElementIterator
638 
BoolElementIterator(DenseElementsAttr attr,size_t dataIndex)639 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
640     DenseElementsAttr attr, size_t dataIndex)
641     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
642           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
643 
operator *() const644 bool DenseElementsAttr::BoolElementIterator::operator*() const {
645   return getBit(getData(), getDataIndex());
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // IntElementIterator
650 
IntElementIterator(DenseElementsAttr attr,size_t dataIndex)651 DenseElementsAttr::IntElementIterator::IntElementIterator(
652     DenseElementsAttr attr, size_t dataIndex)
653     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
654           attr.getRawData().data(), attr.isSplat(), dataIndex),
655       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
656 
operator *() const657 APInt DenseElementsAttr::IntElementIterator::operator*() const {
658   return readBits(getData(),
659                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
660                   bitWidth);
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // ComplexIntElementIterator
665 
ComplexIntElementIterator(DenseElementsAttr attr,size_t dataIndex)666 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
667     DenseElementsAttr attr, size_t dataIndex)
668     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
669                                       std::complex<APInt>, std::complex<APInt>,
670                                       std::complex<APInt>>(
671           attr.getRawData().data(), attr.isSplat(), dataIndex) {
672   auto complexType = attr.getElementType().cast<ComplexType>();
673   bitWidth = getDenseElementBitWidth(complexType.getElementType());
674 }
675 
676 std::complex<APInt>
operator *() const677 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
678   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
679   size_t offset = getDataIndex() * storageWidth * 2;
680   return {readBits(getData(), offset, bitWidth),
681           readBits(getData(), offset + storageWidth, bitWidth)};
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // DenseArrayAttr
686 //===----------------------------------------------------------------------===//
687 
688 /// Custom storage to ensure proper memory alignment for the allocation of
689 /// DenseArray of any element type.
690 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
691   using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
692                            ::llvm::ArrayRef<char>>;
DenseArrayBaseAttrStoragemlir::detail::DenseArrayBaseAttrStorage693   DenseArrayBaseAttrStorage(ShapedType type,
694                             DenseArrayBaseAttr::EltType eltType,
695                             ::llvm::ArrayRef<char> elements)
696       : AttributeStorage(type), eltType(eltType), elements(elements) {}
697 
operator ==mlir::detail::DenseArrayBaseAttrStorage698   bool operator==(const KeyTy &tblgenKey) const {
699     return (getType() == std::get<0>(tblgenKey)) &&
700            (eltType == std::get<1>(tblgenKey)) &&
701            (elements == std::get<2>(tblgenKey));
702   }
703 
hashKeymlir::detail::DenseArrayBaseAttrStorage704   static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
705     return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
706                                 std::get<2>(tblgenKey));
707   }
708 
709   static DenseArrayBaseAttrStorage *
constructmlir::detail::DenseArrayBaseAttrStorage710   construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
711     auto type = std::get<0>(tblgenKey);
712     auto eltType = std::get<1>(tblgenKey);
713     auto elements = std::get<2>(tblgenKey);
714     if (!elements.empty()) {
715       char *alloc = static_cast<char *>(
716           allocator.allocate(elements.size(), alignof(uint64_t)));
717       std::uninitialized_copy(elements.begin(), elements.end(), alloc);
718       elements = ArrayRef<char>(alloc, elements.size());
719     }
720     return new (allocator.allocate<DenseArrayBaseAttrStorage>())
721         DenseArrayBaseAttrStorage(type, eltType, elements);
722   }
723 
724   DenseArrayBaseAttr::EltType eltType;
725   ::llvm::ArrayRef<char> elements;
726 };
727 
getElementType() const728 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
729   return getImpl()->eltType;
730 }
731 
732 const int8_t *
value_begin_impl(OverloadToken<int8_t>) const733 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
734   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
735 }
736 const int16_t *
value_begin_impl(OverloadToken<int16_t>) const737 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
738   return cast<DenseI16ArrayAttr>().asArrayRef().begin();
739 }
740 const int32_t *
value_begin_impl(OverloadToken<int32_t>) const741 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
742   return cast<DenseI32ArrayAttr>().asArrayRef().begin();
743 }
744 const int64_t *
value_begin_impl(OverloadToken<int64_t>) const745 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
746   return cast<DenseI64ArrayAttr>().asArrayRef().begin();
747 }
value_begin_impl(OverloadToken<float>) const748 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
749   return cast<DenseF32ArrayAttr>().asArrayRef().begin();
750 }
751 const double *
value_begin_impl(OverloadToken<double>) const752 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
753   return cast<DenseF64ArrayAttr>().asArrayRef().begin();
754 }
755 
print(AsmPrinter & printer) const756 void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
757   print(printer.getStream());
758 }
759 
printWithoutBraces(raw_ostream & os) const760 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
761   switch (getElementType()) {
762   case DenseArrayBaseAttr::EltType::I8:
763     this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
764     return;
765   case DenseArrayBaseAttr::EltType::I16:
766     this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
767     return;
768   case DenseArrayBaseAttr::EltType::I32:
769     this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
770     return;
771   case DenseArrayBaseAttr::EltType::I64:
772     this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
773     return;
774   case DenseArrayBaseAttr::EltType::F32:
775     this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
776     return;
777   case DenseArrayBaseAttr::EltType::F64:
778     this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
779     return;
780   }
781   llvm_unreachable("<unknown DenseArrayBaseAttr>");
782 }
783 
print(raw_ostream & os) const784 void DenseArrayBaseAttr::print(raw_ostream &os) const {
785   os << "[";
786   printWithoutBraces(os);
787   os << "]";
788 }
789 
790 template <typename T>
print(AsmPrinter & printer) const791 void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
792   print(printer.getStream());
793 }
794 
795 template <typename T>
printWithoutBraces(raw_ostream & os) const796 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
797   ArrayRef<T> values{*this};
798   llvm::interleaveComma(values, os);
799 }
800 
801 /// Specialization for int8_t for forcing printing as number instead of chars.
802 template <>
printWithoutBraces(raw_ostream & os) const803 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
804   ArrayRef<int8_t> values{*this};
805   llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
806 }
807 
808 template <typename T>
print(raw_ostream & os) const809 void DenseArrayAttr<T>::print(raw_ostream &os) const {
810   os << "[";
811   printWithoutBraces(os);
812   os << "]";
813 }
814 
815 /// Parse a single element: generic template for int types, specialized for
816 /// floating points below.
817 template <typename T>
parseDenseArrayAttrElt(AsmParser & parser,T & value)818 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
819   return parser.parseInteger(value);
820 }
821 
822 template <>
parseDenseArrayAttrElt(AsmParser & parser,float & value)823 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
824   double doubleVal;
825   if (parser.parseFloat(doubleVal))
826     return failure();
827   value = doubleVal;
828   return success();
829 }
830 
831 template <>
parseDenseArrayAttrElt(AsmParser & parser,double & value)832 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
833   return parser.parseFloat(value);
834 }
835 
836 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
837 template <typename T>
parseWithoutBraces(AsmParser & parser,Type odsType)838 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
839                                                 Type odsType) {
840   SmallVector<T> data;
841   if (failed(parser.parseCommaSeparatedList([&]() {
842         T value;
843         if (parseDenseArrayAttrElt(parser, value))
844           return failure();
845         data.push_back(value);
846         return success();
847       })))
848     return {};
849   return get(parser.getContext(), data);
850 }
851 
852 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
853 template <typename T>
parse(AsmParser & parser,Type odsType)854 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
855   if (parser.parseLSquare())
856     return {};
857   // Handle empty list case.
858   if (succeeded(parser.parseOptionalRSquare()))
859     return get(parser.getContext(), {});
860   Attribute result = parseWithoutBraces(parser, odsType);
861   if (parser.parseRSquare())
862     return {};
863   return result;
864 }
865 
866 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
867 template <typename T>
operator ArrayRef<T>() const868 DenseArrayAttr<T>::operator ArrayRef<T>() const {
869   ArrayRef<char> raw = getImpl()->elements;
870   assert((raw.size() % sizeof(T)) == 0);
871   return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
872                      raw.size() / sizeof(T));
873 }
874 
875 namespace {
876 /// Mapping from C++ element type to MLIR DenseArrayAttr internals.
877 template <typename T>
878 struct denseArrayAttrEltTypeBuilder;
879 template <>
880 struct denseArrayAttrEltTypeBuilder<int8_t> {
881   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder882   static ShapedType getShapedType(MLIRContext *context,
883                                   ArrayRef<int64_t> shape) {
884     return VectorType::get(shape, IntegerType::get(context, 8));
885   }
886 };
887 template <>
888 struct denseArrayAttrEltTypeBuilder<int16_t> {
889   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder890   static ShapedType getShapedType(MLIRContext *context,
891                                   ArrayRef<int64_t> shape) {
892     return VectorType::get(shape, IntegerType::get(context, 16));
893   }
894 };
895 template <>
896 struct denseArrayAttrEltTypeBuilder<int32_t> {
897   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder898   static ShapedType getShapedType(MLIRContext *context,
899                                   ArrayRef<int64_t> shape) {
900     return VectorType::get(shape, IntegerType::get(context, 32));
901   }
902 };
903 template <>
904 struct denseArrayAttrEltTypeBuilder<int64_t> {
905   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder906   static ShapedType getShapedType(MLIRContext *context,
907                                   ArrayRef<int64_t> shape) {
908     return VectorType::get(shape, IntegerType::get(context, 64));
909   }
910 };
911 template <>
912 struct denseArrayAttrEltTypeBuilder<float> {
913   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder914   static ShapedType getShapedType(MLIRContext *context,
915                                   ArrayRef<int64_t> shape) {
916     return VectorType::get(shape, Float32Type::get(context));
917   }
918 };
919 template <>
920 struct denseArrayAttrEltTypeBuilder<double> {
921   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder922   static ShapedType getShapedType(MLIRContext *context,
923                                   ArrayRef<int64_t> shape) {
924     return VectorType::get(shape, Float64Type::get(context));
925   }
926 };
927 } // namespace
928 
929 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
930 template <typename T>
get(MLIRContext * context,ArrayRef<T> content)931 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
932                                          ArrayRef<T> content) {
933   auto size = static_cast<int64_t>(content.size());
934   auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
935       context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
936   auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
937   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
938                                  content.size() * sizeof(T));
939   return Base::get(context, shapedType, eltType, rawArray)
940       .template cast<DenseArrayAttr<T>>();
941 }
942 
943 template <typename T>
classof(Attribute attr)944 bool DenseArrayAttr<T>::classof(Attribute attr) {
945   return attr.isa<DenseArrayBaseAttr>() &&
946          attr.cast<DenseArrayBaseAttr>().getElementType() ==
947              denseArrayAttrEltTypeBuilder<T>::eltType;
948 }
949 
950 namespace mlir {
951 namespace detail {
952 // Explicit instantiation for all the supported DenseArrayAttr.
953 template class DenseArrayAttr<int8_t>;
954 template class DenseArrayAttr<int16_t>;
955 template class DenseArrayAttr<int32_t>;
956 template class DenseArrayAttr<int64_t>;
957 template class DenseArrayAttr<float>;
958 template class DenseArrayAttr<double>;
959 } // namespace detail
960 } // namespace mlir
961 
962 //===----------------------------------------------------------------------===//
963 // DenseElementsAttr
964 //===----------------------------------------------------------------------===//
965 
966 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)967 bool DenseElementsAttr::classof(Attribute attr) {
968   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
969 }
970 
get(ShapedType type,ArrayRef<Attribute> values)971 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
972                                          ArrayRef<Attribute> values) {
973   assert(hasSameElementsOrSplat(type, values));
974 
975   // If the element type is not based on int/float/index, assume it is a string
976   // type.
977   auto eltType = type.getElementType();
978   if (!type.getElementType().isIntOrIndexOrFloat()) {
979     SmallVector<StringRef, 8> stringValues;
980     stringValues.reserve(values.size());
981     for (Attribute attr : values) {
982       assert(attr.isa<StringAttr>() &&
983              "expected string value for non integer/index/float element");
984       stringValues.push_back(attr.cast<StringAttr>().getValue());
985     }
986     return get(type, stringValues);
987   }
988 
989   // Otherwise, get the raw storage width to use for the allocation.
990   size_t bitWidth = getDenseElementBitWidth(eltType);
991   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
992 
993   // Compress the attribute values into a character buffer.
994   SmallVector<char, 8> data(
995       llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
996   APInt intVal;
997   for (unsigned i = 0, e = values.size(); i < e; ++i) {
998     assert(eltType == values[i].getType() &&
999            "expected attribute value to have element type");
1000     if (eltType.isa<FloatType>())
1001       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
1002     else if (eltType.isa<IntegerType, IndexType>())
1003       intVal = values[i].cast<IntegerAttr>().getValue();
1004     else
1005       llvm_unreachable("unexpected element type");
1006 
1007     assert(intVal.getBitWidth() == bitWidth &&
1008            "expected value to have same bitwidth as element type");
1009     writeBits(data.data(), i * storageBitWidth, intVal);
1010   }
1011 
1012   // Handle the special encoding of splat of bool.
1013   if (values.size() == 1 && values[0].getType().isInteger(1))
1014     data[0] = data[0] ? -1 : 0;
1015 
1016   return DenseIntOrFPElementsAttr::getRaw(type, data);
1017 }
1018 
get(ShapedType type,ArrayRef<bool> values)1019 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1020                                          ArrayRef<bool> values) {
1021   assert(hasSameElementsOrSplat(type, values));
1022   assert(type.getElementType().isInteger(1));
1023 
1024   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
1025 
1026   if (!values.empty()) {
1027     bool isSplat = true;
1028     bool firstValue = values[0];
1029     for (int i = 0, e = values.size(); i != e; ++i) {
1030       isSplat &= values[i] == firstValue;
1031       setBit(buff.data(), i, values[i]);
1032     }
1033 
1034     // Splat of bool is encoded as a byte with all-ones in it.
1035     if (isSplat) {
1036       buff.resize(1);
1037       buff[0] = values[0] ? -1 : 0;
1038     }
1039   }
1040 
1041   return DenseIntOrFPElementsAttr::getRaw(type, buff);
1042 }
1043 
get(ShapedType type,ArrayRef<StringRef> values)1044 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1045                                          ArrayRef<StringRef> values) {
1046   assert(!type.getElementType().isIntOrFloat());
1047   return DenseStringElementsAttr::get(type, values);
1048 }
1049 
1050 /// Constructs a dense integer elements attribute from an array of APInt
1051 /// values. Each APInt value is expected to have the same bitwidth as the
1052 /// element type of 'type'.
get(ShapedType type,ArrayRef<APInt> values)1053 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1054                                          ArrayRef<APInt> values) {
1055   assert(type.getElementType().isIntOrIndex());
1056   assert(hasSameElementsOrSplat(type, values));
1057   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1058   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1059 }
get(ShapedType type,ArrayRef<std::complex<APInt>> values)1060 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1061                                          ArrayRef<std::complex<APInt>> values) {
1062   ComplexType complex = type.getElementType().cast<ComplexType>();
1063   assert(complex.getElementType().isa<IntegerType>());
1064   assert(hasSameElementsOrSplat(type, values));
1065   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1066   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1067                           values.size() * 2);
1068   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1069 }
1070 
1071 // Constructs a dense float elements attribute from an array of APFloat
1072 // values. Each APFloat value is expected to have the same bitwidth as the
1073 // element type of 'type'.
get(ShapedType type,ArrayRef<APFloat> values)1074 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1075                                          ArrayRef<APFloat> values) {
1076   assert(type.getElementType().isa<FloatType>());
1077   assert(hasSameElementsOrSplat(type, values));
1078   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1079   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1080 }
1081 DenseElementsAttr
get(ShapedType type,ArrayRef<std::complex<APFloat>> values)1082 DenseElementsAttr::get(ShapedType type,
1083                        ArrayRef<std::complex<APFloat>> values) {
1084   ComplexType complex = type.getElementType().cast<ComplexType>();
1085   assert(complex.getElementType().isa<FloatType>());
1086   assert(hasSameElementsOrSplat(type, values));
1087   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1088                            values.size() * 2);
1089   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1090   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1091 }
1092 
1093 /// Construct a dense elements attribute from a raw buffer representing the
1094 /// data for this attribute. Users should generally not use this methods as
1095 /// the expected buffer format may not be a form the user expects.
1096 DenseElementsAttr
getFromRawBuffer(ShapedType type,ArrayRef<char> rawBuffer)1097 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1098   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1099 }
1100 
1101 /// Returns true if the given buffer is a valid raw buffer for the given type.
isValidRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool & detectedSplat)1102 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1103                                          ArrayRef<char> rawBuffer,
1104                                          bool &detectedSplat) {
1105   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1106   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1107   int64_t numElements = type.getNumElements();
1108 
1109   // The initializer is always a splat if the result type has a single element.
1110   detectedSplat = numElements == 1;
1111 
1112   // Storage width of 1 is special as it is packed by the bit.
1113   if (storageWidth == 1) {
1114     // Check for a splat, or a buffer equal to the number of elements which
1115     // consists of either all 0's or all 1's.
1116     if (rawBuffer.size() == 1) {
1117       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1118       if (rawByte == 0 || rawByte == 0xff) {
1119         detectedSplat = true;
1120         return true;
1121       }
1122     }
1123 
1124     // This is a valid non-splat buffer if it has the right size.
1125     return rawBufferWidth == llvm::alignTo<8>(numElements);
1126   }
1127 
1128   // All other types are 8-bit aligned, so we can just check the buffer width
1129   // to know if only a single initializer element was passed in.
1130   if (rawBufferWidth == storageWidth) {
1131     detectedSplat = true;
1132     return true;
1133   }
1134 
1135   // The raw buffer is valid if it has the right size.
1136   return rawBufferWidth == storageWidth * numElements;
1137 }
1138 
1139 /// Check the information for a C++ data type, check if this type is valid for
1140 /// the current attribute. This method is used to verify specific type
1141 /// invariants that the templatized 'getValues' method cannot.
isValidIntOrFloat(Type type,int64_t dataEltSize,bool isInt,bool isSigned)1142 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1143                               bool isSigned) {
1144   // Make sure that the data element size is the same as the type element width.
1145   if (getDenseElementBitWidth(type) !=
1146       static_cast<size_t>(dataEltSize * CHAR_BIT))
1147     return false;
1148 
1149   // Check that the element type is either float or integer or index.
1150   if (!isInt)
1151     return type.isa<FloatType>();
1152   if (type.isIndex())
1153     return true;
1154 
1155   auto intType = type.dyn_cast<IntegerType>();
1156   if (!intType)
1157     return false;
1158 
1159   // Make sure signedness semantics is consistent.
1160   if (intType.isSignless())
1161     return true;
1162   return intType.isSigned() ? isSigned : !isSigned;
1163 }
1164 
1165 /// Defaults down the subclass implementation.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1166 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1167                                                    ArrayRef<char> data,
1168                                                    int64_t dataEltSize,
1169                                                    bool isInt, bool isSigned) {
1170   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1171                                                  isSigned);
1172 }
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1173 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1174                                                       ArrayRef<char> data,
1175                                                       int64_t dataEltSize,
1176                                                       bool isInt,
1177                                                       bool isSigned) {
1178   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1179                                                     isInt, isSigned);
1180 }
1181 
isValidIntOrFloat(int64_t dataEltSize,bool isInt,bool isSigned) const1182 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1183                                           bool isSigned) const {
1184   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1185 }
isValidComplex(int64_t dataEltSize,bool isInt,bool isSigned) const1186 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1187                                        bool isSigned) const {
1188   return ::isValidIntOrFloat(
1189       getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1190       isInt, isSigned);
1191 }
1192 
1193 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1194 /// values are the same.
isSplat() const1195 bool DenseElementsAttr::isSplat() const {
1196   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1197 }
1198 
1199 /// Return if the given complex type has an integer element type.
isComplexOfIntType(Type type)1200 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
1201   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1202 }
1203 
getComplexIntValues() const1204 auto DenseElementsAttr::getComplexIntValues() const
1205     -> iterator_range_impl<ComplexIntElementIterator> {
1206   assert(isComplexOfIntType(getElementType()) &&
1207          "expected complex integral type");
1208   return {getType(), ComplexIntElementIterator(*this, 0),
1209           ComplexIntElementIterator(*this, getNumElements())};
1210 }
complex_value_begin() const1211 auto DenseElementsAttr::complex_value_begin() const
1212     -> ComplexIntElementIterator {
1213   assert(isComplexOfIntType(getElementType()) &&
1214          "expected complex integral type");
1215   return ComplexIntElementIterator(*this, 0);
1216 }
complex_value_end() const1217 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
1218   assert(isComplexOfIntType(getElementType()) &&
1219          "expected complex integral type");
1220   return ComplexIntElementIterator(*this, getNumElements());
1221 }
1222 
1223 /// Return the held element values as a range of APFloat. The element type of
1224 /// this attribute must be of float type.
getFloatValues() const1225 auto DenseElementsAttr::getFloatValues() const
1226     -> iterator_range_impl<FloatElementIterator> {
1227   auto elementType = getElementType().cast<FloatType>();
1228   const auto &elementSemantics = elementType.getFloatSemantics();
1229   return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1230           FloatElementIterator(elementSemantics, raw_int_end())};
1231 }
float_value_begin() const1232 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1233   auto elementType = getElementType().cast<FloatType>();
1234   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
1235 }
float_value_end() const1236 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1237   auto elementType = getElementType().cast<FloatType>();
1238   return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
1239 }
1240 
getComplexFloatValues() const1241 auto DenseElementsAttr::getComplexFloatValues() const
1242     -> iterator_range_impl<ComplexFloatElementIterator> {
1243   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1244   assert(eltTy.isa<FloatType>() && "expected complex float type");
1245   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1246   return {getType(),
1247           {semantics, {*this, 0}},
1248           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1249 }
complex_float_value_begin() const1250 auto DenseElementsAttr::complex_float_value_begin() const
1251     -> ComplexFloatElementIterator {
1252   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1253   assert(eltTy.isa<FloatType>() && "expected complex float type");
1254   return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
1255 }
complex_float_value_end() const1256 auto DenseElementsAttr::complex_float_value_end() const
1257     -> ComplexFloatElementIterator {
1258   Type eltTy = getElementType().cast<ComplexType>().getElementType();
1259   assert(eltTy.isa<FloatType>() && "expected complex float type");
1260   return {eltTy.cast<FloatType>().getFloatSemantics(),
1261           {*this, static_cast<size_t>(getNumElements())}};
1262 }
1263 
1264 /// Return the raw storage data held by this attribute.
getRawData() const1265 ArrayRef<char> DenseElementsAttr::getRawData() const {
1266   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1267 }
1268 
getRawStringData() const1269 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1270   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1271 }
1272 
1273 /// Return a new DenseElementsAttr that has the same data as the current
1274 /// attribute, but has been reshaped to 'newType'. The new type must have the
1275 /// same total number of elements as well as element type.
reshape(ShapedType newType)1276 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1277   ShapedType curType = getType();
1278   if (curType == newType)
1279     return *this;
1280 
1281   assert(newType.getElementType() == curType.getElementType() &&
1282          "expected the same element type");
1283   assert(newType.getNumElements() == curType.getNumElements() &&
1284          "expected the same number of elements");
1285   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1286 }
1287 
resizeSplat(ShapedType newType)1288 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1289   assert(isSplat() && "expected a splat type");
1290 
1291   ShapedType curType = getType();
1292   if (curType == newType)
1293     return *this;
1294 
1295   assert(newType.getElementType() == curType.getElementType() &&
1296          "expected the same element type");
1297   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1298 }
1299 
1300 /// Return a new DenseElementsAttr that has the same data as the current
1301 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1302 /// type must have the same shape and element types of the same bitwidth as the
1303 /// current type.
bitcast(Type newElType)1304 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1305   ShapedType curType = getType();
1306   Type curElType = curType.getElementType();
1307   if (curElType == newElType)
1308     return *this;
1309 
1310   assert(getDenseElementBitWidth(newElType) ==
1311              getDenseElementBitWidth(curElType) &&
1312          "expected element types with the same bitwidth");
1313   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1314                                           getRawData());
1315 }
1316 
1317 DenseElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1318 DenseElementsAttr::mapValues(Type newElementType,
1319                              function_ref<APInt(const APInt &)> mapping) const {
1320   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1321 }
1322 
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1323 DenseElementsAttr DenseElementsAttr::mapValues(
1324     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1325   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1326 }
1327 
getType() const1328 ShapedType DenseElementsAttr::getType() const {
1329   return Attribute::getType().cast<ShapedType>();
1330 }
1331 
getElementType() const1332 Type DenseElementsAttr::getElementType() const {
1333   return getType().getElementType();
1334 }
1335 
getNumElements() const1336 int64_t DenseElementsAttr::getNumElements() const {
1337   return getType().getNumElements();
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // DenseIntOrFPElementsAttr
1342 //===----------------------------------------------------------------------===//
1343 
1344 /// Utility method to write a range of APInt values to a buffer.
1345 template <typename APRangeT>
writeAPIntsToBuffer(size_t storageWidth,std::vector<char> & data,APRangeT && values)1346 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1347                                 APRangeT &&values) {
1348   size_t numValues = llvm::size(values);
1349   data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1350   size_t offset = 0;
1351   for (auto it = values.begin(), e = values.end(); it != e;
1352        ++it, offset += storageWidth) {
1353     assert((*it).getBitWidth() <= storageWidth);
1354     writeBits(data.data(), offset, *it);
1355   }
1356 
1357   // Handle the special encoding of splat of a boolean.
1358   if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1359     data[0] = data[0] ? -1 : 0;
1360 }
1361 
1362 /// Constructs a dense elements attribute from an array of raw APFloat values.
1363 /// Each APFloat value is expected to have the same bitwidth as the element
1364 /// type of 'type'. 'type' must be a vector or tensor with static shape.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APFloat> values)1365 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1366                                                    size_t storageWidth,
1367                                                    ArrayRef<APFloat> values) {
1368   std::vector<char> data;
1369   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1370   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1371   return DenseIntOrFPElementsAttr::getRaw(type, data);
1372 }
1373 
1374 /// Constructs a dense elements attribute from an array of raw APInt values.
1375 /// Each APInt value is expected to have the same bitwidth as the element type
1376 /// of 'type'.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APInt> values)1377 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1378                                                    size_t storageWidth,
1379                                                    ArrayRef<APInt> values) {
1380   std::vector<char> data;
1381   writeAPIntsToBuffer(storageWidth, data, values);
1382   return DenseIntOrFPElementsAttr::getRaw(type, data);
1383 }
1384 
getRaw(ShapedType type,ArrayRef<char> data)1385 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1386                                                    ArrayRef<char> data) {
1387   assert((type.isa<RankedTensorType, VectorType>()) &&
1388          "type must be ranked tensor or vector");
1389   assert(type.hasStaticShape() && "type must have static shape");
1390   bool isSplat = false;
1391   bool isValid = isValidRawBuffer(type, data, isSplat);
1392   assert(isValid);
1393   (void)isValid;
1394   return Base::get(type.getContext(), type, data, isSplat);
1395 }
1396 
1397 /// Overload of the raw 'get' method that asserts that the given type is of
1398 /// complex type. This method is used to verify type invariants that the
1399 /// templatized 'get' method cannot.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1400 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1401                                                           ArrayRef<char> data,
1402                                                           int64_t dataEltSize,
1403                                                           bool isInt,
1404                                                           bool isSigned) {
1405   assert(::isValidIntOrFloat(
1406       type.getElementType().cast<ComplexType>().getElementType(),
1407       dataEltSize / 2, isInt, isSigned));
1408 
1409   int64_t numElements = data.size() / dataEltSize;
1410   (void)numElements;
1411   assert(numElements == 1 || numElements == type.getNumElements());
1412   return getRaw(type, data);
1413 }
1414 
1415 /// Overload of the 'getRaw' method that asserts that the given type is of
1416 /// integer type. This method is used to verify type invariants that the
1417 /// templatized 'get' method cannot.
1418 DenseElementsAttr
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1419 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1420                                            int64_t dataEltSize, bool isInt,
1421                                            bool isSigned) {
1422   assert(
1423       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1424 
1425   int64_t numElements = data.size() / dataEltSize;
1426   assert(numElements == 1 || numElements == type.getNumElements());
1427   (void)numElements;
1428   return getRaw(type, data);
1429 }
1430 
convertEndianOfCharForBEmachine(const char * inRawData,char * outRawData,size_t elementBitWidth,size_t numElements)1431 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1432     const char *inRawData, char *outRawData, size_t elementBitWidth,
1433     size_t numElements) {
1434   using llvm::support::ulittle16_t;
1435   using llvm::support::ulittle32_t;
1436   using llvm::support::ulittle64_t;
1437 
1438   assert(llvm::support::endian::system_endianness() == // NOLINT
1439          llvm::support::endianness::big);              // NOLINT
1440   // NOLINT to avoid warning message about replacing by static_assert()
1441 
1442   // Following std::copy_n always converts endianness on BE machine.
1443   switch (elementBitWidth) {
1444   case 16: {
1445     const ulittle16_t *inRawDataPos =
1446         reinterpret_cast<const ulittle16_t *>(inRawData);
1447     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1448     std::copy_n(inRawDataPos, numElements, outDataPos);
1449     break;
1450   }
1451   case 32: {
1452     const ulittle32_t *inRawDataPos =
1453         reinterpret_cast<const ulittle32_t *>(inRawData);
1454     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1455     std::copy_n(inRawDataPos, numElements, outDataPos);
1456     break;
1457   }
1458   case 64: {
1459     const ulittle64_t *inRawDataPos =
1460         reinterpret_cast<const ulittle64_t *>(inRawData);
1461     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1462     std::copy_n(inRawDataPos, numElements, outDataPos);
1463     break;
1464   }
1465   default: {
1466     size_t nBytes = elementBitWidth / CHAR_BIT;
1467     for (size_t i = 0; i < nBytes; i++)
1468       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1469     break;
1470   }
1471   }
1472 }
1473 
convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,MutableArrayRef<char> outRawData,ShapedType type)1474 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1475     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1476     ShapedType type) {
1477   size_t numElements = type.getNumElements();
1478   Type elementType = type.getElementType();
1479   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1480     elementType = complexTy.getElementType();
1481     numElements = numElements * 2;
1482   }
1483   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1484   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1485          inRawData.size() <= outRawData.size());
1486   if (elementBitWidth <= CHAR_BIT)
1487     std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1488   else
1489     convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1490                                     elementBitWidth, numElements);
1491 }
1492 
1493 //===----------------------------------------------------------------------===//
1494 // DenseFPElementsAttr
1495 //===----------------------------------------------------------------------===//
1496 
1497 template <typename Fn, typename Attr>
mappingHelper(Fn mapping,Attr & attr,ShapedType inType,Type newElementType,llvm::SmallVectorImpl<char> & data)1498 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1499                                 Type newElementType,
1500                                 llvm::SmallVectorImpl<char> &data) {
1501   size_t bitWidth = getDenseElementBitWidth(newElementType);
1502   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1503 
1504   ShapedType newArrayType;
1505   if (inType.isa<RankedTensorType>())
1506     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1507   else if (inType.isa<UnrankedTensorType>())
1508     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1509   else if (auto vType = inType.dyn_cast<VectorType>())
1510     newArrayType = VectorType::get(vType.getShape(), newElementType,
1511                                    vType.getNumScalableDims());
1512   else
1513     assert(newArrayType && "Unhandled tensor type");
1514 
1515   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1516   data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1517 
1518   // Functor used to process a single element value of the attribute.
1519   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1520     auto newInt = mapping(value);
1521     assert(newInt.getBitWidth() == bitWidth);
1522     writeBits(data.data(), index * storageBitWidth, newInt);
1523   };
1524 
1525   // Check for the splat case.
1526   if (attr.isSplat()) {
1527     processElt(*attr.begin(), /*index=*/0);
1528     return newArrayType;
1529   }
1530 
1531   // Otherwise, process all of the element values.
1532   uint64_t elementIdx = 0;
1533   for (auto value : attr)
1534     processElt(value, elementIdx++);
1535   return newArrayType;
1536 }
1537 
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1538 DenseElementsAttr DenseFPElementsAttr::mapValues(
1539     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1540   llvm::SmallVector<char, 8> elementData;
1541   auto newArrayType =
1542       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1543 
1544   return getRaw(newArrayType, elementData);
1545 }
1546 
1547 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1548 bool DenseFPElementsAttr::classof(Attribute attr) {
1549   return attr.isa<DenseElementsAttr>() &&
1550          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // DenseIntElementsAttr
1555 //===----------------------------------------------------------------------===//
1556 
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1557 DenseElementsAttr DenseIntElementsAttr::mapValues(
1558     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1559   llvm::SmallVector<char, 8> elementData;
1560   auto newArrayType =
1561       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1562   return getRaw(newArrayType, elementData);
1563 }
1564 
1565 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1566 bool DenseIntElementsAttr::classof(Attribute attr) {
1567   return attr.isa<DenseElementsAttr>() &&
1568          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // OpaqueElementsAttr
1573 //===----------------------------------------------------------------------===//
1574 
decode(ElementsAttr & result)1575 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1576   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1577   if (!dialect)
1578     return true;
1579   auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
1580   if (!interface)
1581     return true;
1582   return failed(interface->decode(*this, result));
1583 }
1584 
1585 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef value,ShapedType type)1586 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1587                            StringAttr dialect, StringRef value,
1588                            ShapedType type) {
1589   if (!Dialect::isValidNamespace(dialect.strref()))
1590     return emitError() << "invalid dialect namespace '" << dialect << "'";
1591   return success();
1592 }
1593 
1594 //===----------------------------------------------------------------------===//
1595 // SparseElementsAttr
1596 //===----------------------------------------------------------------------===//
1597 
1598 /// Get a zero APFloat for the given sparse attribute.
getZeroAPFloat() const1599 APFloat SparseElementsAttr::getZeroAPFloat() const {
1600   auto eltType = getElementType().cast<FloatType>();
1601   return APFloat(eltType.getFloatSemantics());
1602 }
1603 
1604 /// Get a zero APInt for the given sparse attribute.
getZeroAPInt() const1605 APInt SparseElementsAttr::getZeroAPInt() const {
1606   auto eltType = getElementType().cast<IntegerType>();
1607   return APInt::getZero(eltType.getWidth());
1608 }
1609 
1610 /// Get a zero attribute for the given attribute type.
getZeroAttr() const1611 Attribute SparseElementsAttr::getZeroAttr() const {
1612   auto eltType = getElementType();
1613 
1614   // Handle floating point elements.
1615   if (eltType.isa<FloatType>())
1616     return FloatAttr::get(eltType, 0);
1617 
1618   // Handle complex elements.
1619   if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1620     auto eltType = complexTy.getElementType();
1621     Attribute zero;
1622     if (eltType.isa<FloatType>())
1623       zero = FloatAttr::get(eltType, 0);
1624     else // must be integer
1625       zero = IntegerAttr::get(eltType, 0);
1626     return ArrayAttr::get(complexTy.getContext(),
1627                           ArrayRef<Attribute>{zero, zero});
1628   }
1629 
1630   // Handle string type.
1631   if (getValues().isa<DenseStringElementsAttr>())
1632     return StringAttr::get("", eltType);
1633 
1634   // Otherwise, this is an integer.
1635   return IntegerAttr::get(eltType, 0);
1636 }
1637 
1638 /// Flatten, and return, all of the sparse indices in this attribute in
1639 /// row-major order.
getFlattenedSparseIndices() const1640 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1641   std::vector<ptrdiff_t> flatSparseIndices;
1642 
1643   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1644   // as a 1-D index array.
1645   auto sparseIndices = getIndices();
1646   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1647   if (sparseIndices.isSplat()) {
1648     SmallVector<uint64_t, 8> indices(getType().getRank(),
1649                                      *sparseIndexValues.begin());
1650     flatSparseIndices.push_back(getFlattenedIndex(indices));
1651     return flatSparseIndices;
1652   }
1653 
1654   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1655   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1656   size_t rank = getType().getRank();
1657   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1658     flatSparseIndices.push_back(getFlattenedIndex(
1659         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1660   return flatSparseIndices;
1661 }
1662 
1663 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ShapedType type,DenseIntElementsAttr sparseIndices,DenseElementsAttr values)1664 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1665                            ShapedType type, DenseIntElementsAttr sparseIndices,
1666                            DenseElementsAttr values) {
1667   ShapedType valuesType = values.getType();
1668   if (valuesType.getRank() != 1)
1669     return emitError() << "expected 1-d tensor for sparse element values";
1670 
1671   // Verify the indices and values shape.
1672   ShapedType indicesType = sparseIndices.getType();
1673   auto emitShapeError = [&]() {
1674     return emitError() << "expected shape ([" << type.getShape()
1675                        << "]); inferred shape of indices literal (["
1676                        << indicesType.getShape()
1677                        << "]); inferred shape of values literal (["
1678                        << valuesType.getShape() << "])";
1679   };
1680   // Verify indices shape.
1681   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1682   if (indicesRank == 2) {
1683     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1684       return emitShapeError();
1685   } else if (indicesRank != 1 || rank != 1) {
1686     return emitShapeError();
1687   }
1688   // Verify the values shape.
1689   int64_t numSparseIndices = indicesType.getDimSize(0);
1690   if (numSparseIndices != valuesType.getDimSize(0))
1691     return emitShapeError();
1692 
1693   // Verify that the sparse indices are within the value shape.
1694   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1695     return emitError()
1696            << "sparse index #" << indexNum
1697            << " is not contained within the value shape, with index=[" << index
1698            << "], and type=" << type;
1699   };
1700 
1701   // Handle the case where the index values are a splat.
1702   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1703   if (sparseIndices.isSplat()) {
1704     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1705     if (!ElementsAttr::isValidIndex(type, indices))
1706       return emitIndexError(0, indices);
1707     return success();
1708   }
1709 
1710   // Otherwise, reinterpret each index as an ArrayRef.
1711   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1712     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1713                              rank);
1714     if (!ElementsAttr::isValidIndex(type, index))
1715       return emitIndexError(i, index);
1716   }
1717 
1718   return success();
1719 }
1720 
1721 //===----------------------------------------------------------------------===//
1722 // TypeAttr
1723 //===----------------------------------------------------------------------===//
1724 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const1725 void TypeAttr::walkImmediateSubElements(
1726     function_ref<void(Attribute)> walkAttrsFn,
1727     function_ref<void(Type)> walkTypesFn) const {
1728   walkTypesFn(getValue());
1729 }
1730 
1731 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const1732 TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
1733                                       ArrayRef<Type> replTypes) const {
1734   return get(replTypes[0]);
1735 }
1736