1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/Endian.h"
23
24 using namespace mlir;
25 using namespace mlir::detail;
26
27 //===----------------------------------------------------------------------===//
28 /// Tablegen Attribute Definitions
29 //===----------------------------------------------------------------------===//
30
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/IR/BuiltinAttributes.cpp.inc"
33
34 //===----------------------------------------------------------------------===//
35 // BuiltinDialect
36 //===----------------------------------------------------------------------===//
37
registerAttributes()38 void BuiltinDialect::registerAttributes() {
39 addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
40 DenseIntOrFPElementsAttr, DenseStringElementsAttr,
41 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
42 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
43 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
44 }
45
46 //===----------------------------------------------------------------------===//
47 // ArrayAttr
48 //===----------------------------------------------------------------------===//
49
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const50 void ArrayAttr::walkImmediateSubElements(
51 function_ref<void(Attribute)> walkAttrsFn,
52 function_ref<void(Type)> walkTypesFn) const {
53 for (Attribute attr : getValue())
54 walkAttrsFn(attr);
55 }
56
57 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const58 ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
59 ArrayRef<Type> replTypes) const {
60 return get(getContext(), replAttrs);
61 }
62
63 //===----------------------------------------------------------------------===//
64 // DictionaryAttr
65 //===----------------------------------------------------------------------===//
66
67 /// Helper function that does either an in place sort or sorts from source array
68 /// into destination. If inPlace then storage is both the source and the
69 /// destination, else value is the source and storage destination. Returns
70 /// whether source was sorted.
71 template <bool inPlace>
dictionaryAttrSort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)72 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
73 SmallVectorImpl<NamedAttribute> &storage) {
74 // Specialize for the common case.
75 switch (value.size()) {
76 case 0:
77 // Zero already sorted.
78 if (!inPlace)
79 storage.clear();
80 break;
81 case 1:
82 // One already sorted but may need to be copied.
83 if (!inPlace)
84 storage.assign({value[0]});
85 break;
86 case 2: {
87 bool isSorted = value[0] < value[1];
88 if (inPlace) {
89 if (!isSorted)
90 std::swap(storage[0], storage[1]);
91 } else if (isSorted) {
92 storage.assign({value[0], value[1]});
93 } else {
94 storage.assign({value[1], value[0]});
95 }
96 return !isSorted;
97 }
98 default:
99 if (!inPlace)
100 storage.assign(value.begin(), value.end());
101 // Check to see they are sorted already.
102 bool isSorted = llvm::is_sorted(value);
103 // If not, do a general sort.
104 if (!isSorted)
105 llvm::array_pod_sort(storage.begin(), storage.end());
106 return !isSorted;
107 }
108 return false;
109 }
110
111 /// Returns an entry with a duplicate name from the given sorted array of named
112 /// attributes. Returns llvm::None if all elements have unique names.
113 static Optional<NamedAttribute>
findDuplicateElement(ArrayRef<NamedAttribute> value)114 findDuplicateElement(ArrayRef<NamedAttribute> value) {
115 const Optional<NamedAttribute> none{llvm::None};
116 if (value.size() < 2)
117 return none;
118
119 if (value.size() == 2)
120 return value[0].getName() == value[1].getName() ? value[0] : none;
121
122 const auto *it = std::adjacent_find(value.begin(), value.end(),
123 [](NamedAttribute l, NamedAttribute r) {
124 return l.getName() == r.getName();
125 });
126 return it != value.end() ? *it : none;
127 }
128
sort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)129 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
130 SmallVectorImpl<NamedAttribute> &storage) {
131 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
132 assert(!findDuplicateElement(storage) &&
133 "DictionaryAttr element names must be unique");
134 return isSorted;
135 }
136
sortInPlace(SmallVectorImpl<NamedAttribute> & array)137 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
138 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
139 assert(!findDuplicateElement(array) &&
140 "DictionaryAttr element names must be unique");
141 return isSorted;
142 }
143
144 Optional<NamedAttribute>
findDuplicate(SmallVectorImpl<NamedAttribute> & array,bool isSorted)145 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
146 bool isSorted) {
147 if (!isSorted)
148 dictionaryAttrSort</*inPlace=*/true>(array, array);
149 return findDuplicateElement(array);
150 }
151
get(MLIRContext * context,ArrayRef<NamedAttribute> value)152 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
153 ArrayRef<NamedAttribute> value) {
154 if (value.empty())
155 return DictionaryAttr::getEmpty(context);
156
157 // We need to sort the element list to canonicalize it.
158 SmallVector<NamedAttribute, 8> storage;
159 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
160 value = storage;
161 assert(!findDuplicateElement(value) &&
162 "DictionaryAttr element names must be unique");
163 return Base::get(context, value);
164 }
165 /// Construct a dictionary with an array of values that is known to already be
166 /// sorted by name and uniqued.
getWithSorted(MLIRContext * context,ArrayRef<NamedAttribute> value)167 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
168 ArrayRef<NamedAttribute> value) {
169 if (value.empty())
170 return DictionaryAttr::getEmpty(context);
171 // Ensure that the attribute elements are unique and sorted.
172 assert(llvm::is_sorted(
173 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
174 "expected attribute values to be sorted");
175 assert(!findDuplicateElement(value) &&
176 "DictionaryAttr element names must be unique");
177 return Base::get(context, value);
178 }
179
180 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const181 Attribute DictionaryAttr::get(StringRef name) const {
182 auto it = impl::findAttrSorted(begin(), end(), name);
183 return it.second ? it.first->getValue() : Attribute();
184 }
get(StringAttr name) const185 Attribute DictionaryAttr::get(StringAttr name) const {
186 auto it = impl::findAttrSorted(begin(), end(), name);
187 return it.second ? it.first->getValue() : Attribute();
188 }
189
190 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const191 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
192 auto it = impl::findAttrSorted(begin(), end(), name);
193 return it.second ? *it.first : Optional<NamedAttribute>();
194 }
getNamed(StringAttr name) const195 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
196 auto it = impl::findAttrSorted(begin(), end(), name);
197 return it.second ? *it.first : Optional<NamedAttribute>();
198 }
199
200 /// Return whether the specified attribute is present.
contains(StringRef name) const201 bool DictionaryAttr::contains(StringRef name) const {
202 return impl::findAttrSorted(begin(), end(), name).second;
203 }
contains(StringAttr name) const204 bool DictionaryAttr::contains(StringAttr name) const {
205 return impl::findAttrSorted(begin(), end(), name).second;
206 }
207
begin() const208 DictionaryAttr::iterator DictionaryAttr::begin() const {
209 return getValue().begin();
210 }
end() const211 DictionaryAttr::iterator DictionaryAttr::end() const {
212 return getValue().end();
213 }
size() const214 size_t DictionaryAttr::size() const { return getValue().size(); }
215
getEmptyUnchecked(MLIRContext * context)216 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
217 return Base::get(context, ArrayRef<NamedAttribute>());
218 }
219
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const220 void DictionaryAttr::walkImmediateSubElements(
221 function_ref<void(Attribute)> walkAttrsFn,
222 function_ref<void(Type)> walkTypesFn) const {
223 for (const NamedAttribute &attr : getValue())
224 walkAttrsFn(attr.getValue());
225 }
226
227 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const228 DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
229 ArrayRef<Type> replTypes) const {
230 std::vector<NamedAttribute> vec = getValue().vec();
231 for (auto &it : llvm::enumerate(replAttrs))
232 vec[it.index()].setValue(it.value());
233
234 // The above only modifies the mapped value, but not the key, and therefore
235 // not the order of the elements. It remains sorted
236 return getWithSorted(getContext(), vec);
237 }
238
239 //===----------------------------------------------------------------------===//
240 // StringAttr
241 //===----------------------------------------------------------------------===//
242
getEmptyStringAttrUnchecked(MLIRContext * context)243 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
244 return Base::get(context, "", NoneType::get(context));
245 }
246
247 /// Twine support for StringAttr.
get(MLIRContext * context,const Twine & twine)248 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
249 // Fast-path empty twine.
250 if (twine.isTriviallyEmpty())
251 return get(context);
252 SmallVector<char, 32> tempStr;
253 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
254 }
255
256 /// Twine support for StringAttr.
get(const Twine & twine,Type type)257 StringAttr StringAttr::get(const Twine &twine, Type type) {
258 SmallVector<char, 32> tempStr;
259 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
260 }
261
getValue() const262 StringRef StringAttr::getValue() const { return getImpl()->value; }
263
getReferencedDialect() const264 Dialect *StringAttr::getReferencedDialect() const {
265 return getImpl()->referencedDialect;
266 }
267
268 //===----------------------------------------------------------------------===//
269 // FloatAttr
270 //===----------------------------------------------------------------------===//
271
getValueAsDouble() const272 double FloatAttr::getValueAsDouble() const {
273 return getValueAsDouble(getValue());
274 }
getValueAsDouble(APFloat value)275 double FloatAttr::getValueAsDouble(APFloat value) {
276 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
277 bool losesInfo = false;
278 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
279 &losesInfo);
280 }
281 return value.convertToDouble();
282 }
283
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,APFloat value)284 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
285 Type type, APFloat value) {
286 // Verify that the type is correct.
287 if (!type.isa<FloatType>())
288 return emitError() << "expected floating point type";
289
290 // Verify that the type semantics match that of the value.
291 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
292 return emitError()
293 << "FloatAttr type doesn't match the type implied by its value";
294 }
295 return success();
296 }
297
298 //===----------------------------------------------------------------------===//
299 // SymbolRefAttr
300 //===----------------------------------------------------------------------===//
301
get(MLIRContext * ctx,StringRef value,ArrayRef<FlatSymbolRefAttr> nestedRefs)302 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
303 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
304 return get(StringAttr::get(ctx, value), nestedRefs);
305 }
306
get(MLIRContext * ctx,StringRef value)307 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
308 return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
309 }
310
get(StringAttr value)311 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
312 return get(value, {}).cast<FlatSymbolRefAttr>();
313 }
314
get(Operation * symbol)315 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
316 auto symName =
317 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
318 assert(symName && "value does not have a valid symbol name");
319 return SymbolRefAttr::get(symName);
320 }
321
getLeafReference() const322 StringAttr SymbolRefAttr::getLeafReference() const {
323 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
324 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
325 }
326
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const327 void SymbolRefAttr::walkImmediateSubElements(
328 function_ref<void(Attribute)> walkAttrsFn,
329 function_ref<void(Type)> walkTypesFn) const {
330 walkAttrsFn(getRootReference());
331 for (FlatSymbolRefAttr ref : getNestedReferences())
332 walkAttrsFn(ref);
333 }
334
335 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const336 SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
337 ArrayRef<Type> replTypes) const {
338 ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
339 ArrayRef<FlatSymbolRefAttr> nestedRefs(
340 static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
341 rawNestedRefs.size());
342 return get(replAttrs[0].cast<StringAttr>(), nestedRefs);
343 }
344
345 //===----------------------------------------------------------------------===//
346 // IntegerAttr
347 //===----------------------------------------------------------------------===//
348
getInt() const349 int64_t IntegerAttr::getInt() const {
350 assert((getType().isIndex() || getType().isSignlessInteger()) &&
351 "must be signless integer");
352 return getValue().getSExtValue();
353 }
354
getSInt() const355 int64_t IntegerAttr::getSInt() const {
356 assert(getType().isSignedInteger() && "must be signed integer");
357 return getValue().getSExtValue();
358 }
359
getUInt() const360 uint64_t IntegerAttr::getUInt() const {
361 assert(getType().isUnsignedInteger() && "must be unsigned integer");
362 return getValue().getZExtValue();
363 }
364
365 /// Return the value as an APSInt which carries the signed from the type of
366 /// the attribute. This traps on signless integers types!
getAPSInt() const367 APSInt IntegerAttr::getAPSInt() const {
368 assert(!getType().isSignlessInteger() &&
369 "Signless integers don't carry a sign for APSInt");
370 return APSInt(getValue(), getType().isUnsignedInteger());
371 }
372
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,APInt value)373 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
374 Type type, APInt value) {
375 if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
376 if (integerType.getWidth() != value.getBitWidth())
377 return emitError() << "integer type bit width (" << integerType.getWidth()
378 << ") doesn't match value bit width ("
379 << value.getBitWidth() << ")";
380 return success();
381 }
382 if (type.isa<IndexType>())
383 return success();
384 return emitError() << "expected integer or index type";
385 }
386
getBoolAttrUnchecked(IntegerType type,bool value)387 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
388 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
389 return attr.cast<BoolAttr>();
390 }
391
392 //===----------------------------------------------------------------------===//
393 // BoolAttr
394 //===----------------------------------------------------------------------===//
395
getValue() const396 bool BoolAttr::getValue() const {
397 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
398 return storage->value.getBoolValue();
399 }
400
classof(Attribute attr)401 bool BoolAttr::classof(Attribute attr) {
402 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
403 return intAttr && intAttr.getType().isSignlessInteger(1);
404 }
405
406 //===----------------------------------------------------------------------===//
407 // OpaqueAttr
408 //===----------------------------------------------------------------------===//
409
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef attrData,Type type)410 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
411 StringAttr dialect, StringRef attrData,
412 Type type) {
413 if (!Dialect::isValidNamespace(dialect.strref()))
414 return emitError() << "invalid dialect namespace '" << dialect << "'";
415
416 // Check that the dialect is actually registered.
417 MLIRContext *context = dialect.getContext();
418 if (!context->allowsUnregisteredDialects() &&
419 !context->getLoadedDialect(dialect.strref())) {
420 return emitError()
421 << "#" << dialect << "<\"" << attrData << "\"> : " << type
422 << " attribute created with unregistered dialect. If this is "
423 "intended, please call allowUnregisteredDialects() on the "
424 "MLIRContext, or use -allow-unregistered-dialect with "
425 "the MLIR opt tool used";
426 }
427
428 return success();
429 }
430
431 //===----------------------------------------------------------------------===//
432 // DenseElementsAttr Utilities
433 //===----------------------------------------------------------------------===//
434
435 /// Get the bitwidth of a dense element type within the buffer.
436 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
getDenseElementStorageWidth(size_t origWidth)437 static size_t getDenseElementStorageWidth(size_t origWidth) {
438 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
439 }
getDenseElementStorageWidth(Type elementType)440 static size_t getDenseElementStorageWidth(Type elementType) {
441 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
442 }
443
/// Sets bit `bitPos` of `rawData` to `value`. Bits are numbered from the least
/// significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}
451
/// Returns whether bit `bitPos` of `rawData` is set. Bits are numbered from
/// the least significant bit of byte 0.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
456
457 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
458 /// BE format.
copyAPIntToArrayForBEmachine(APInt value,size_t numBytes,char * result)459 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
460 char *result) {
461 assert(llvm::support::endian::system_endianness() == // NOLINT
462 llvm::support::endianness::big); // NOLINT
463 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
464
465 // Copy the words filled with data.
466 // For example, when `value` has 2 words, the first word is filled with data.
467 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
468 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
469 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
470 numFilledWords, result);
471 // Convert last word of APInt to LE format and store it in char
472 // array(`valueLE`).
473 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
474 size_t lastWordPos = numFilledWords;
475 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
476 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
477 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
478 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
479 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
480 // and store it in `result`.
481 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
482 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
483 valueLE.begin(), result + lastWordPos,
484 (numBytes - lastWordPos) * CHAR_BIT, 1);
485 }
486
487 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
488 /// format.
copyArrayToAPIntForBEmachine(const char * inArray,size_t numBytes,APInt & result)489 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
490 APInt &result) {
491 assert(llvm::support::endian::system_endianness() == // NOLINT
492 llvm::support::endianness::big); // NOLINT
493 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
494
495 // Copy the data that fills the word of `result` from `inArray`.
496 // For example, when `result` has 2 words, the first word will be filled with
497 // data. So, the first 8 bytes are copied from `inArray` here.
498 // `inArray` (10 bytes, BE): |abcdefgh|ij|
499 // ==> `result` (2 words, BE): |abcdefgh|--------|
500 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
501 std::copy_n(
502 inArray, numFilledWords,
503 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
504
505 // Convert array data which will be last word of `result` to LE format, and
506 // store it in char array(`inArrayLE`).
507 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
508 size_t lastWordPos = numFilledWords;
509 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
510 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
511 inArray + lastWordPos, inArrayLE.begin(),
512 (numBytes - lastWordPos) * CHAR_BIT, 1);
513
514 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
515 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
516 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
517 inArrayLE.begin(),
518 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
519 lastWordPos,
520 APInt::APINT_BITS_PER_WORD, 1);
521 }
522
523 /// Writes value to the bit position `bitPos` in array `rawData`.
writeBits(char * rawData,size_t bitPos,APInt value)524 static void writeBits(char *rawData, size_t bitPos, APInt value) {
525 size_t bitWidth = value.getBitWidth();
526
527 // If the bitwidth is 1 we just toggle the specific bit.
528 if (bitWidth == 1)
529 return setBit(rawData, bitPos, value.isOneValue());
530
531 // Otherwise, the bit position is guaranteed to be byte aligned.
532 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
533 if (llvm::support::endian::system_endianness() ==
534 llvm::support::endianness::big) {
535 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
536 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
537 // work correctly in BE format.
538 // ex. `value` (2 words including 10 bytes)
539 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
540 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
541 rawData + (bitPos / CHAR_BIT));
542 } else {
543 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
544 llvm::divideCeil(bitWidth, CHAR_BIT),
545 rawData + (bitPos / CHAR_BIT));
546 }
547 }
548
549 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
550 /// `rawData`.
readBits(const char * rawData,size_t bitPos,size_t bitWidth)551 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
552 // Handle a boolean bit position.
553 if (bitWidth == 1)
554 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
555
556 // Otherwise, the bit position must be 8-bit aligned.
557 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
558 APInt result(bitWidth, 0);
559 if (llvm::support::endian::system_endianness() ==
560 llvm::support::endianness::big) {
561 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
562 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
563 // work correctly in BE format.
564 // ex. `result` (2 words including 10 bytes)
565 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
566 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
567 llvm::divideCeil(bitWidth, CHAR_BIT), result);
568 } else {
569 std::copy_n(rawData + (bitPos / CHAR_BIT),
570 llvm::divideCeil(bitWidth, CHAR_BIT),
571 const_cast<char *>(
572 reinterpret_cast<const char *>(result.getRawData())));
573 }
574 return result;
575 }
576
577 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
578 /// the same element count as 'type'.
579 template <typename Values>
hasSameElementsOrSplat(ShapedType type,const Values & values)580 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
581 return (values.size() == 1) ||
582 (type.getNumElements() == static_cast<int64_t>(values.size()));
583 }
584
585 //===----------------------------------------------------------------------===//
586 // DenseElementsAttr Iterators
587 //===----------------------------------------------------------------------===//
588
589 //===----------------------------------------------------------------------===//
590 // AttributeElementIterator
591
AttributeElementIterator(DenseElementsAttr attr,size_t index)592 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
593 DenseElementsAttr attr, size_t index)
594 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
595 Attribute, Attribute, Attribute>(
596 attr.getAsOpaquePointer(), index) {}
597
operator *() const598 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
599 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
600 Type eltTy = owner.getElementType();
601 if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
602 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
603 if (eltTy.isa<IndexType>())
604 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
605 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
606 IntElementIterator intIt(owner, index);
607 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
608 return FloatAttr::get(eltTy, *floatIt);
609 }
610 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
611 auto complexEltTy = complexTy.getElementType();
612 ComplexIntElementIterator complexIntIt(owner, index);
613 if (complexEltTy.isa<IntegerType>()) {
614 auto value = *complexIntIt;
615 auto real = IntegerAttr::get(complexEltTy, value.real());
616 auto imag = IntegerAttr::get(complexEltTy, value.imag());
617 return ArrayAttr::get(complexTy.getContext(),
618 ArrayRef<Attribute>{real, imag});
619 }
620
621 ComplexFloatElementIterator complexFloatIt(
622 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
623 auto value = *complexFloatIt;
624 auto real = FloatAttr::get(complexEltTy, value.real());
625 auto imag = FloatAttr::get(complexEltTy, value.imag());
626 return ArrayAttr::get(complexTy.getContext(),
627 ArrayRef<Attribute>{real, imag});
628 }
629 if (owner.isa<DenseStringElementsAttr>()) {
630 ArrayRef<StringRef> vals = owner.getRawStringData();
631 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
632 }
633 llvm_unreachable("unexpected element type");
634 }
635
636 //===----------------------------------------------------------------------===//
637 // BoolElementIterator
638
BoolElementIterator(DenseElementsAttr attr,size_t dataIndex)639 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
640 DenseElementsAttr attr, size_t dataIndex)
641 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
642 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
643
operator *() const644 bool DenseElementsAttr::BoolElementIterator::operator*() const {
645 return getBit(getData(), getDataIndex());
646 }
647
648 //===----------------------------------------------------------------------===//
649 // IntElementIterator
650
IntElementIterator(DenseElementsAttr attr,size_t dataIndex)651 DenseElementsAttr::IntElementIterator::IntElementIterator(
652 DenseElementsAttr attr, size_t dataIndex)
653 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
654 attr.getRawData().data(), attr.isSplat(), dataIndex),
655 bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
656
operator *() const657 APInt DenseElementsAttr::IntElementIterator::operator*() const {
658 return readBits(getData(),
659 getDataIndex() * getDenseElementStorageWidth(bitWidth),
660 bitWidth);
661 }
662
663 //===----------------------------------------------------------------------===//
664 // ComplexIntElementIterator
665
ComplexIntElementIterator(DenseElementsAttr attr,size_t dataIndex)666 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
667 DenseElementsAttr attr, size_t dataIndex)
668 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
669 std::complex<APInt>, std::complex<APInt>,
670 std::complex<APInt>>(
671 attr.getRawData().data(), attr.isSplat(), dataIndex) {
672 auto complexType = attr.getElementType().cast<ComplexType>();
673 bitWidth = getDenseElementBitWidth(complexType.getElementType());
674 }
675
676 std::complex<APInt>
operator *() const677 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
678 size_t storageWidth = getDenseElementStorageWidth(bitWidth);
679 size_t offset = getDataIndex() * storageWidth * 2;
680 return {readBits(getData(), offset, bitWidth),
681 readBits(getData(), offset + storageWidth, bitWidth)};
682 }
683
684 //===----------------------------------------------------------------------===//
685 // DenseArrayAttr
686 //===----------------------------------------------------------------------===//
687
688 /// Custom storage to ensure proper memory alignment for the allocation of
689 /// DenseArray of any element type.
690 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
691 using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
692 ::llvm::ArrayRef<char>>;
DenseArrayBaseAttrStoragemlir::detail::DenseArrayBaseAttrStorage693 DenseArrayBaseAttrStorage(ShapedType type,
694 DenseArrayBaseAttr::EltType eltType,
695 ::llvm::ArrayRef<char> elements)
696 : AttributeStorage(type), eltType(eltType), elements(elements) {}
697
operator ==mlir::detail::DenseArrayBaseAttrStorage698 bool operator==(const KeyTy &tblgenKey) const {
699 return (getType() == std::get<0>(tblgenKey)) &&
700 (eltType == std::get<1>(tblgenKey)) &&
701 (elements == std::get<2>(tblgenKey));
702 }
703
hashKeymlir::detail::DenseArrayBaseAttrStorage704 static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
705 return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
706 std::get<2>(tblgenKey));
707 }
708
709 static DenseArrayBaseAttrStorage *
constructmlir::detail::DenseArrayBaseAttrStorage710 construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
711 auto type = std::get<0>(tblgenKey);
712 auto eltType = std::get<1>(tblgenKey);
713 auto elements = std::get<2>(tblgenKey);
714 if (!elements.empty()) {
715 char *alloc = static_cast<char *>(
716 allocator.allocate(elements.size(), alignof(uint64_t)));
717 std::uninitialized_copy(elements.begin(), elements.end(), alloc);
718 elements = ArrayRef<char>(alloc, elements.size());
719 }
720 return new (allocator.allocate<DenseArrayBaseAttrStorage>())
721 DenseArrayBaseAttrStorage(type, eltType, elements);
722 }
723
724 DenseArrayBaseAttr::EltType eltType;
725 ::llvm::ArrayRef<char> elements;
726 };
727
getElementType() const728 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
729 return getImpl()->eltType;
730 }
731
732 const int8_t *
value_begin_impl(OverloadToken<int8_t>) const733 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
734 return cast<DenseI8ArrayAttr>().asArrayRef().begin();
735 }
736 const int16_t *
value_begin_impl(OverloadToken<int16_t>) const737 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
738 return cast<DenseI16ArrayAttr>().asArrayRef().begin();
739 }
740 const int32_t *
value_begin_impl(OverloadToken<int32_t>) const741 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
742 return cast<DenseI32ArrayAttr>().asArrayRef().begin();
743 }
744 const int64_t *
value_begin_impl(OverloadToken<int64_t>) const745 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
746 return cast<DenseI64ArrayAttr>().asArrayRef().begin();
747 }
value_begin_impl(OverloadToken<float>) const748 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
749 return cast<DenseF32ArrayAttr>().asArrayRef().begin();
750 }
751 const double *
value_begin_impl(OverloadToken<double>) const752 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
753 return cast<DenseF64ArrayAttr>().asArrayRef().begin();
754 }
755
print(AsmPrinter & printer) const756 void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
757 print(printer.getStream());
758 }
759
printWithoutBraces(raw_ostream & os) const760 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
761 switch (getElementType()) {
762 case DenseArrayBaseAttr::EltType::I8:
763 this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
764 return;
765 case DenseArrayBaseAttr::EltType::I16:
766 this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
767 return;
768 case DenseArrayBaseAttr::EltType::I32:
769 this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
770 return;
771 case DenseArrayBaseAttr::EltType::I64:
772 this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
773 return;
774 case DenseArrayBaseAttr::EltType::F32:
775 this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
776 return;
777 case DenseArrayBaseAttr::EltType::F64:
778 this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
779 return;
780 }
781 llvm_unreachable("<unknown DenseArrayBaseAttr>");
782 }
783
print(raw_ostream & os) const784 void DenseArrayBaseAttr::print(raw_ostream &os) const {
785 os << "[";
786 printWithoutBraces(os);
787 os << "]";
788 }
789
790 template <typename T>
print(AsmPrinter & printer) const791 void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
792 print(printer.getStream());
793 }
794
795 template <typename T>
printWithoutBraces(raw_ostream & os) const796 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
797 ArrayRef<T> values{*this};
798 llvm::interleaveComma(values, os);
799 }
800
801 /// Specialization for int8_t for forcing printing as number instead of chars.
802 template <>
printWithoutBraces(raw_ostream & os) const803 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
804 ArrayRef<int8_t> values{*this};
805 llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
806 }
807
808 template <typename T>
print(raw_ostream & os) const809 void DenseArrayAttr<T>::print(raw_ostream &os) const {
810 os << "[";
811 printWithoutBraces(os);
812 os << "]";
813 }
814
815 /// Parse a single element: generic template for int types, specialized for
816 /// floating points below.
817 template <typename T>
parseDenseArrayAttrElt(AsmParser & parser,T & value)818 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
819 return parser.parseInteger(value);
820 }
821
822 template <>
parseDenseArrayAttrElt(AsmParser & parser,float & value)823 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
824 double doubleVal;
825 if (parser.parseFloat(doubleVal))
826 return failure();
827 value = doubleVal;
828 return success();
829 }
830
831 template <>
parseDenseArrayAttrElt(AsmParser & parser,double & value)832 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
833 return parser.parseFloat(value);
834 }
835
836 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
837 template <typename T>
parseWithoutBraces(AsmParser & parser,Type odsType)838 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
839 Type odsType) {
840 SmallVector<T> data;
841 if (failed(parser.parseCommaSeparatedList([&]() {
842 T value;
843 if (parseDenseArrayAttrElt(parser, value))
844 return failure();
845 data.push_back(value);
846 return success();
847 })))
848 return {};
849 return get(parser.getContext(), data);
850 }
851
852 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
853 template <typename T>
parse(AsmParser & parser,Type odsType)854 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
855 if (parser.parseLSquare())
856 return {};
857 // Handle empty list case.
858 if (succeeded(parser.parseOptionalRSquare()))
859 return get(parser.getContext(), {});
860 Attribute result = parseWithoutBraces(parser, odsType);
861 if (parser.parseRSquare())
862 return {};
863 return result;
864 }
865
866 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
867 template <typename T>
operator ArrayRef<T>() const868 DenseArrayAttr<T>::operator ArrayRef<T>() const {
869 ArrayRef<char> raw = getImpl()->elements;
870 assert((raw.size() % sizeof(T)) == 0);
871 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
872 raw.size() / sizeof(T));
873 }
874
875 namespace {
876 /// Mapping from C++ element type to MLIR DenseArrayAttr internals.
877 template <typename T>
878 struct denseArrayAttrEltTypeBuilder;
879 template <>
880 struct denseArrayAttrEltTypeBuilder<int8_t> {
881 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder882 static ShapedType getShapedType(MLIRContext *context,
883 ArrayRef<int64_t> shape) {
884 return VectorType::get(shape, IntegerType::get(context, 8));
885 }
886 };
887 template <>
888 struct denseArrayAttrEltTypeBuilder<int16_t> {
889 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder890 static ShapedType getShapedType(MLIRContext *context,
891 ArrayRef<int64_t> shape) {
892 return VectorType::get(shape, IntegerType::get(context, 16));
893 }
894 };
895 template <>
896 struct denseArrayAttrEltTypeBuilder<int32_t> {
897 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder898 static ShapedType getShapedType(MLIRContext *context,
899 ArrayRef<int64_t> shape) {
900 return VectorType::get(shape, IntegerType::get(context, 32));
901 }
902 };
903 template <>
904 struct denseArrayAttrEltTypeBuilder<int64_t> {
905 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder906 static ShapedType getShapedType(MLIRContext *context,
907 ArrayRef<int64_t> shape) {
908 return VectorType::get(shape, IntegerType::get(context, 64));
909 }
910 };
911 template <>
912 struct denseArrayAttrEltTypeBuilder<float> {
913 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder914 static ShapedType getShapedType(MLIRContext *context,
915 ArrayRef<int64_t> shape) {
916 return VectorType::get(shape, Float32Type::get(context));
917 }
918 };
919 template <>
920 struct denseArrayAttrEltTypeBuilder<double> {
921 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
getShapedType__anonfb591a410511::denseArrayAttrEltTypeBuilder922 static ShapedType getShapedType(MLIRContext *context,
923 ArrayRef<int64_t> shape) {
924 return VectorType::get(shape, Float64Type::get(context));
925 }
926 };
927 } // namespace
928
929 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
930 template <typename T>
get(MLIRContext * context,ArrayRef<T> content)931 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
932 ArrayRef<T> content) {
933 auto size = static_cast<int64_t>(content.size());
934 auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
935 context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
936 auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
937 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
938 content.size() * sizeof(T));
939 return Base::get(context, shapedType, eltType, rawArray)
940 .template cast<DenseArrayAttr<T>>();
941 }
942
943 template <typename T>
classof(Attribute attr)944 bool DenseArrayAttr<T>::classof(Attribute attr) {
945 return attr.isa<DenseArrayBaseAttr>() &&
946 attr.cast<DenseArrayBaseAttr>().getElementType() ==
947 denseArrayAttrEltTypeBuilder<T>::eltType;
948 }
949
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
// These six element types are exactly the ones with a
// denseArrayAttrEltTypeBuilder specialization in this file. The DenseArrayAttr
// member templates (get, parse, print, classof, ...) are defined in this .cpp,
// so instantiating them here is what makes them usable from other translation
// units (assuming the header only declares them -- NOTE(review): confirm).
template class DenseArrayAttr<int8_t>;
template class DenseArrayAttr<int16_t>;
template class DenseArrayAttr<int32_t>;
template class DenseArrayAttr<int64_t>;
template class DenseArrayAttr<float>;
template class DenseArrayAttr<double>;
} // namespace detail
} // namespace mlir
961
962 //===----------------------------------------------------------------------===//
963 // DenseElementsAttr
964 //===----------------------------------------------------------------------===//
965
966 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)967 bool DenseElementsAttr::classof(Attribute attr) {
968 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
969 }
970
get(ShapedType type,ArrayRef<Attribute> values)971 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
972 ArrayRef<Attribute> values) {
973 assert(hasSameElementsOrSplat(type, values));
974
975 // If the element type is not based on int/float/index, assume it is a string
976 // type.
977 auto eltType = type.getElementType();
978 if (!type.getElementType().isIntOrIndexOrFloat()) {
979 SmallVector<StringRef, 8> stringValues;
980 stringValues.reserve(values.size());
981 for (Attribute attr : values) {
982 assert(attr.isa<StringAttr>() &&
983 "expected string value for non integer/index/float element");
984 stringValues.push_back(attr.cast<StringAttr>().getValue());
985 }
986 return get(type, stringValues);
987 }
988
989 // Otherwise, get the raw storage width to use for the allocation.
990 size_t bitWidth = getDenseElementBitWidth(eltType);
991 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
992
993 // Compress the attribute values into a character buffer.
994 SmallVector<char, 8> data(
995 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
996 APInt intVal;
997 for (unsigned i = 0, e = values.size(); i < e; ++i) {
998 assert(eltType == values[i].getType() &&
999 "expected attribute value to have element type");
1000 if (eltType.isa<FloatType>())
1001 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
1002 else if (eltType.isa<IntegerType, IndexType>())
1003 intVal = values[i].cast<IntegerAttr>().getValue();
1004 else
1005 llvm_unreachable("unexpected element type");
1006
1007 assert(intVal.getBitWidth() == bitWidth &&
1008 "expected value to have same bitwidth as element type");
1009 writeBits(data.data(), i * storageBitWidth, intVal);
1010 }
1011
1012 // Handle the special encoding of splat of bool.
1013 if (values.size() == 1 && values[0].getType().isInteger(1))
1014 data[0] = data[0] ? -1 : 0;
1015
1016 return DenseIntOrFPElementsAttr::getRaw(type, data);
1017 }
1018
get(ShapedType type,ArrayRef<bool> values)1019 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1020 ArrayRef<bool> values) {
1021 assert(hasSameElementsOrSplat(type, values));
1022 assert(type.getElementType().isInteger(1));
1023
1024 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
1025
1026 if (!values.empty()) {
1027 bool isSplat = true;
1028 bool firstValue = values[0];
1029 for (int i = 0, e = values.size(); i != e; ++i) {
1030 isSplat &= values[i] == firstValue;
1031 setBit(buff.data(), i, values[i]);
1032 }
1033
1034 // Splat of bool is encoded as a byte with all-ones in it.
1035 if (isSplat) {
1036 buff.resize(1);
1037 buff[0] = values[0] ? -1 : 0;
1038 }
1039 }
1040
1041 return DenseIntOrFPElementsAttr::getRaw(type, buff);
1042 }
1043
get(ShapedType type,ArrayRef<StringRef> values)1044 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1045 ArrayRef<StringRef> values) {
1046 assert(!type.getElementType().isIntOrFloat());
1047 return DenseStringElementsAttr::get(type, values);
1048 }
1049
1050 /// Constructs a dense integer elements attribute from an array of APInt
1051 /// values. Each APInt value is expected to have the same bitwidth as the
1052 /// element type of 'type'.
get(ShapedType type,ArrayRef<APInt> values)1053 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1054 ArrayRef<APInt> values) {
1055 assert(type.getElementType().isIntOrIndex());
1056 assert(hasSameElementsOrSplat(type, values));
1057 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1058 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1059 }
get(ShapedType type,ArrayRef<std::complex<APInt>> values)1060 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1061 ArrayRef<std::complex<APInt>> values) {
1062 ComplexType complex = type.getElementType().cast<ComplexType>();
1063 assert(complex.getElementType().isa<IntegerType>());
1064 assert(hasSameElementsOrSplat(type, values));
1065 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1066 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1067 values.size() * 2);
1068 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1069 }
1070
1071 // Constructs a dense float elements attribute from an array of APFloat
1072 // values. Each APFloat value is expected to have the same bitwidth as the
1073 // element type of 'type'.
get(ShapedType type,ArrayRef<APFloat> values)1074 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1075 ArrayRef<APFloat> values) {
1076 assert(type.getElementType().isa<FloatType>());
1077 assert(hasSameElementsOrSplat(type, values));
1078 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1079 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1080 }
1081 DenseElementsAttr
get(ShapedType type,ArrayRef<std::complex<APFloat>> values)1082 DenseElementsAttr::get(ShapedType type,
1083 ArrayRef<std::complex<APFloat>> values) {
1084 ComplexType complex = type.getElementType().cast<ComplexType>();
1085 assert(complex.getElementType().isa<FloatType>());
1086 assert(hasSameElementsOrSplat(type, values));
1087 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1088 values.size() * 2);
1089 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1090 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1091 }
1092
1093 /// Construct a dense elements attribute from a raw buffer representing the
1094 /// data for this attribute. Users should generally not use this methods as
1095 /// the expected buffer format may not be a form the user expects.
1096 DenseElementsAttr
getFromRawBuffer(ShapedType type,ArrayRef<char> rawBuffer)1097 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1098 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1099 }
1100
1101 /// Returns true if the given buffer is a valid raw buffer for the given type.
isValidRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool & detectedSplat)1102 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1103 ArrayRef<char> rawBuffer,
1104 bool &detectedSplat) {
1105 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1106 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1107 int64_t numElements = type.getNumElements();
1108
1109 // The initializer is always a splat if the result type has a single element.
1110 detectedSplat = numElements == 1;
1111
1112 // Storage width of 1 is special as it is packed by the bit.
1113 if (storageWidth == 1) {
1114 // Check for a splat, or a buffer equal to the number of elements which
1115 // consists of either all 0's or all 1's.
1116 if (rawBuffer.size() == 1) {
1117 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1118 if (rawByte == 0 || rawByte == 0xff) {
1119 detectedSplat = true;
1120 return true;
1121 }
1122 }
1123
1124 // This is a valid non-splat buffer if it has the right size.
1125 return rawBufferWidth == llvm::alignTo<8>(numElements);
1126 }
1127
1128 // All other types are 8-bit aligned, so we can just check the buffer width
1129 // to know if only a single initializer element was passed in.
1130 if (rawBufferWidth == storageWidth) {
1131 detectedSplat = true;
1132 return true;
1133 }
1134
1135 // The raw buffer is valid if it has the right size.
1136 return rawBufferWidth == storageWidth * numElements;
1137 }
1138
1139 /// Check the information for a C++ data type, check if this type is valid for
1140 /// the current attribute. This method is used to verify specific type
1141 /// invariants that the templatized 'getValues' method cannot.
isValidIntOrFloat(Type type,int64_t dataEltSize,bool isInt,bool isSigned)1142 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1143 bool isSigned) {
1144 // Make sure that the data element size is the same as the type element width.
1145 if (getDenseElementBitWidth(type) !=
1146 static_cast<size_t>(dataEltSize * CHAR_BIT))
1147 return false;
1148
1149 // Check that the element type is either float or integer or index.
1150 if (!isInt)
1151 return type.isa<FloatType>();
1152 if (type.isIndex())
1153 return true;
1154
1155 auto intType = type.dyn_cast<IntegerType>();
1156 if (!intType)
1157 return false;
1158
1159 // Make sure signedness semantics is consistent.
1160 if (intType.isSignless())
1161 return true;
1162 return intType.isSigned() ? isSigned : !isSigned;
1163 }
1164
1165 /// Defaults down the subclass implementation.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1166 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1167 ArrayRef<char> data,
1168 int64_t dataEltSize,
1169 bool isInt, bool isSigned) {
1170 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1171 isSigned);
1172 }
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1173 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1174 ArrayRef<char> data,
1175 int64_t dataEltSize,
1176 bool isInt,
1177 bool isSigned) {
1178 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1179 isInt, isSigned);
1180 }
1181
isValidIntOrFloat(int64_t dataEltSize,bool isInt,bool isSigned) const1182 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1183 bool isSigned) const {
1184 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1185 }
isValidComplex(int64_t dataEltSize,bool isInt,bool isSigned) const1186 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1187 bool isSigned) const {
1188 return ::isValidIntOrFloat(
1189 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1190 isInt, isSigned);
1191 }
1192
1193 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1194 /// values are the same.
isSplat() const1195 bool DenseElementsAttr::isSplat() const {
1196 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1197 }
1198
1199 /// Return if the given complex type has an integer element type.
isComplexOfIntType(Type type)1200 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
1201 return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1202 }
1203
getComplexIntValues() const1204 auto DenseElementsAttr::getComplexIntValues() const
1205 -> iterator_range_impl<ComplexIntElementIterator> {
1206 assert(isComplexOfIntType(getElementType()) &&
1207 "expected complex integral type");
1208 return {getType(), ComplexIntElementIterator(*this, 0),
1209 ComplexIntElementIterator(*this, getNumElements())};
1210 }
complex_value_begin() const1211 auto DenseElementsAttr::complex_value_begin() const
1212 -> ComplexIntElementIterator {
1213 assert(isComplexOfIntType(getElementType()) &&
1214 "expected complex integral type");
1215 return ComplexIntElementIterator(*this, 0);
1216 }
complex_value_end() const1217 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
1218 assert(isComplexOfIntType(getElementType()) &&
1219 "expected complex integral type");
1220 return ComplexIntElementIterator(*this, getNumElements());
1221 }
1222
1223 /// Return the held element values as a range of APFloat. The element type of
1224 /// this attribute must be of float type.
getFloatValues() const1225 auto DenseElementsAttr::getFloatValues() const
1226 -> iterator_range_impl<FloatElementIterator> {
1227 auto elementType = getElementType().cast<FloatType>();
1228 const auto &elementSemantics = elementType.getFloatSemantics();
1229 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1230 FloatElementIterator(elementSemantics, raw_int_end())};
1231 }
float_value_begin() const1232 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1233 auto elementType = getElementType().cast<FloatType>();
1234 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
1235 }
float_value_end() const1236 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1237 auto elementType = getElementType().cast<FloatType>();
1238 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
1239 }
1240
getComplexFloatValues() const1241 auto DenseElementsAttr::getComplexFloatValues() const
1242 -> iterator_range_impl<ComplexFloatElementIterator> {
1243 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1244 assert(eltTy.isa<FloatType>() && "expected complex float type");
1245 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1246 return {getType(),
1247 {semantics, {*this, 0}},
1248 {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1249 }
complex_float_value_begin() const1250 auto DenseElementsAttr::complex_float_value_begin() const
1251 -> ComplexFloatElementIterator {
1252 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1253 assert(eltTy.isa<FloatType>() && "expected complex float type");
1254 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
1255 }
complex_float_value_end() const1256 auto DenseElementsAttr::complex_float_value_end() const
1257 -> ComplexFloatElementIterator {
1258 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1259 assert(eltTy.isa<FloatType>() && "expected complex float type");
1260 return {eltTy.cast<FloatType>().getFloatSemantics(),
1261 {*this, static_cast<size_t>(getNumElements())}};
1262 }
1263
1264 /// Return the raw storage data held by this attribute.
getRawData() const1265 ArrayRef<char> DenseElementsAttr::getRawData() const {
1266 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1267 }
1268
getRawStringData() const1269 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1270 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1271 }
1272
1273 /// Return a new DenseElementsAttr that has the same data as the current
1274 /// attribute, but has been reshaped to 'newType'. The new type must have the
1275 /// same total number of elements as well as element type.
reshape(ShapedType newType)1276 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1277 ShapedType curType = getType();
1278 if (curType == newType)
1279 return *this;
1280
1281 assert(newType.getElementType() == curType.getElementType() &&
1282 "expected the same element type");
1283 assert(newType.getNumElements() == curType.getNumElements() &&
1284 "expected the same number of elements");
1285 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1286 }
1287
resizeSplat(ShapedType newType)1288 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1289 assert(isSplat() && "expected a splat type");
1290
1291 ShapedType curType = getType();
1292 if (curType == newType)
1293 return *this;
1294
1295 assert(newType.getElementType() == curType.getElementType() &&
1296 "expected the same element type");
1297 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1298 }
1299
1300 /// Return a new DenseElementsAttr that has the same data as the current
1301 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1302 /// type must have the same shape and element types of the same bitwidth as the
1303 /// current type.
bitcast(Type newElType)1304 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1305 ShapedType curType = getType();
1306 Type curElType = curType.getElementType();
1307 if (curElType == newElType)
1308 return *this;
1309
1310 assert(getDenseElementBitWidth(newElType) ==
1311 getDenseElementBitWidth(curElType) &&
1312 "expected element types with the same bitwidth");
1313 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1314 getRawData());
1315 }
1316
1317 DenseElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1318 DenseElementsAttr::mapValues(Type newElementType,
1319 function_ref<APInt(const APInt &)> mapping) const {
1320 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1321 }
1322
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1323 DenseElementsAttr DenseElementsAttr::mapValues(
1324 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1325 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1326 }
1327
getType() const1328 ShapedType DenseElementsAttr::getType() const {
1329 return Attribute::getType().cast<ShapedType>();
1330 }
1331
getElementType() const1332 Type DenseElementsAttr::getElementType() const {
1333 return getType().getElementType();
1334 }
1335
getNumElements() const1336 int64_t DenseElementsAttr::getNumElements() const {
1337 return getType().getNumElements();
1338 }
1339
1340 //===----------------------------------------------------------------------===//
1341 // DenseIntOrFPElementsAttr
1342 //===----------------------------------------------------------------------===//
1343
1344 /// Utility method to write a range of APInt values to a buffer.
1345 template <typename APRangeT>
writeAPIntsToBuffer(size_t storageWidth,std::vector<char> & data,APRangeT && values)1346 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1347 APRangeT &&values) {
1348 size_t numValues = llvm::size(values);
1349 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1350 size_t offset = 0;
1351 for (auto it = values.begin(), e = values.end(); it != e;
1352 ++it, offset += storageWidth) {
1353 assert((*it).getBitWidth() <= storageWidth);
1354 writeBits(data.data(), offset, *it);
1355 }
1356
1357 // Handle the special encoding of splat of a boolean.
1358 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1359 data[0] = data[0] ? -1 : 0;
1360 }
1361
1362 /// Constructs a dense elements attribute from an array of raw APFloat values.
1363 /// Each APFloat value is expected to have the same bitwidth as the element
1364 /// type of 'type'. 'type' must be a vector or tensor with static shape.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APFloat> values)1365 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1366 size_t storageWidth,
1367 ArrayRef<APFloat> values) {
1368 std::vector<char> data;
1369 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1370 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1371 return DenseIntOrFPElementsAttr::getRaw(type, data);
1372 }
1373
1374 /// Constructs a dense elements attribute from an array of raw APInt values.
1375 /// Each APInt value is expected to have the same bitwidth as the element type
1376 /// of 'type'.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APInt> values)1377 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1378 size_t storageWidth,
1379 ArrayRef<APInt> values) {
1380 std::vector<char> data;
1381 writeAPIntsToBuffer(storageWidth, data, values);
1382 return DenseIntOrFPElementsAttr::getRaw(type, data);
1383 }
1384
getRaw(ShapedType type,ArrayRef<char> data)1385 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1386 ArrayRef<char> data) {
1387 assert((type.isa<RankedTensorType, VectorType>()) &&
1388 "type must be ranked tensor or vector");
1389 assert(type.hasStaticShape() && "type must have static shape");
1390 bool isSplat = false;
1391 bool isValid = isValidRawBuffer(type, data, isSplat);
1392 assert(isValid);
1393 (void)isValid;
1394 return Base::get(type.getContext(), type, data, isSplat);
1395 }
1396
1397 /// Overload of the raw 'get' method that asserts that the given type is of
1398 /// complex type. This method is used to verify type invariants that the
1399 /// templatized 'get' method cannot.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1400 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1401 ArrayRef<char> data,
1402 int64_t dataEltSize,
1403 bool isInt,
1404 bool isSigned) {
1405 assert(::isValidIntOrFloat(
1406 type.getElementType().cast<ComplexType>().getElementType(),
1407 dataEltSize / 2, isInt, isSigned));
1408
1409 int64_t numElements = data.size() / dataEltSize;
1410 (void)numElements;
1411 assert(numElements == 1 || numElements == type.getNumElements());
1412 return getRaw(type, data);
1413 }
1414
1415 /// Overload of the 'getRaw' method that asserts that the given type is of
1416 /// integer type. This method is used to verify type invariants that the
1417 /// templatized 'get' method cannot.
1418 DenseElementsAttr
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1419 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1420 int64_t dataEltSize, bool isInt,
1421 bool isSigned) {
1422 assert(
1423 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1424
1425 int64_t numElements = data.size() / dataEltSize;
1426 assert(numElements == 1 || numElements == type.getNumElements());
1427 (void)numElements;
1428 return getRaw(type, data);
1429 }
1430
convertEndianOfCharForBEmachine(const char * inRawData,char * outRawData,size_t elementBitWidth,size_t numElements)1431 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1432 const char *inRawData, char *outRawData, size_t elementBitWidth,
1433 size_t numElements) {
1434 using llvm::support::ulittle16_t;
1435 using llvm::support::ulittle32_t;
1436 using llvm::support::ulittle64_t;
1437
1438 assert(llvm::support::endian::system_endianness() == // NOLINT
1439 llvm::support::endianness::big); // NOLINT
1440 // NOLINT to avoid warning message about replacing by static_assert()
1441
1442 // Following std::copy_n always converts endianness on BE machine.
1443 switch (elementBitWidth) {
1444 case 16: {
1445 const ulittle16_t *inRawDataPos =
1446 reinterpret_cast<const ulittle16_t *>(inRawData);
1447 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1448 std::copy_n(inRawDataPos, numElements, outDataPos);
1449 break;
1450 }
1451 case 32: {
1452 const ulittle32_t *inRawDataPos =
1453 reinterpret_cast<const ulittle32_t *>(inRawData);
1454 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1455 std::copy_n(inRawDataPos, numElements, outDataPos);
1456 break;
1457 }
1458 case 64: {
1459 const ulittle64_t *inRawDataPos =
1460 reinterpret_cast<const ulittle64_t *>(inRawData);
1461 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1462 std::copy_n(inRawDataPos, numElements, outDataPos);
1463 break;
1464 }
1465 default: {
1466 size_t nBytes = elementBitWidth / CHAR_BIT;
1467 for (size_t i = 0; i < nBytes; i++)
1468 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1469 break;
1470 }
1471 }
1472 }
1473
convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,MutableArrayRef<char> outRawData,ShapedType type)1474 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1475 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1476 ShapedType type) {
1477 size_t numElements = type.getNumElements();
1478 Type elementType = type.getElementType();
1479 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1480 elementType = complexTy.getElementType();
1481 numElements = numElements * 2;
1482 }
1483 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1484 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1485 inRawData.size() <= outRawData.size());
1486 if (elementBitWidth <= CHAR_BIT)
1487 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1488 else
1489 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1490 elementBitWidth, numElements);
1491 }
1492
1493 //===----------------------------------------------------------------------===//
1494 // DenseFPElementsAttr
1495 //===----------------------------------------------------------------------===//
1496
1497 template <typename Fn, typename Attr>
mappingHelper(Fn mapping,Attr & attr,ShapedType inType,Type newElementType,llvm::SmallVectorImpl<char> & data)1498 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1499 Type newElementType,
1500 llvm::SmallVectorImpl<char> &data) {
1501 size_t bitWidth = getDenseElementBitWidth(newElementType);
1502 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1503
1504 ShapedType newArrayType;
1505 if (inType.isa<RankedTensorType>())
1506 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1507 else if (inType.isa<UnrankedTensorType>())
1508 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1509 else if (auto vType = inType.dyn_cast<VectorType>())
1510 newArrayType = VectorType::get(vType.getShape(), newElementType,
1511 vType.getNumScalableDims());
1512 else
1513 assert(newArrayType && "Unhandled tensor type");
1514
1515 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1516 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1517
1518 // Functor used to process a single element value of the attribute.
1519 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1520 auto newInt = mapping(value);
1521 assert(newInt.getBitWidth() == bitWidth);
1522 writeBits(data.data(), index * storageBitWidth, newInt);
1523 };
1524
1525 // Check for the splat case.
1526 if (attr.isSplat()) {
1527 processElt(*attr.begin(), /*index=*/0);
1528 return newArrayType;
1529 }
1530
1531 // Otherwise, process all of the element values.
1532 uint64_t elementIdx = 0;
1533 for (auto value : attr)
1534 processElt(value, elementIdx++);
1535 return newArrayType;
1536 }
1537
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1538 DenseElementsAttr DenseFPElementsAttr::mapValues(
1539 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1540 llvm::SmallVector<char, 8> elementData;
1541 auto newArrayType =
1542 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1543
1544 return getRaw(newArrayType, elementData);
1545 }
1546
1547 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1548 bool DenseFPElementsAttr::classof(Attribute attr) {
1549 return attr.isa<DenseElementsAttr>() &&
1550 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1551 }
1552
1553 //===----------------------------------------------------------------------===//
1554 // DenseIntElementsAttr
1555 //===----------------------------------------------------------------------===//
1556
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1557 DenseElementsAttr DenseIntElementsAttr::mapValues(
1558 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1559 llvm::SmallVector<char, 8> elementData;
1560 auto newArrayType =
1561 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1562 return getRaw(newArrayType, elementData);
1563 }
1564
1565 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1566 bool DenseIntElementsAttr::classof(Attribute attr) {
1567 return attr.isa<DenseElementsAttr>() &&
1568 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1569 }
1570
1571 //===----------------------------------------------------------------------===//
1572 // OpaqueElementsAttr
1573 //===----------------------------------------------------------------------===//
1574
decode(ElementsAttr & result)1575 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1576 Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1577 if (!dialect)
1578 return true;
1579 auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
1580 if (!interface)
1581 return true;
1582 return failed(interface->decode(*this, result));
1583 }
1584
1585 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef value,ShapedType type)1586 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1587 StringAttr dialect, StringRef value,
1588 ShapedType type) {
1589 if (!Dialect::isValidNamespace(dialect.strref()))
1590 return emitError() << "invalid dialect namespace '" << dialect << "'";
1591 return success();
1592 }
1593
1594 //===----------------------------------------------------------------------===//
1595 // SparseElementsAttr
1596 //===----------------------------------------------------------------------===//
1597
1598 /// Get a zero APFloat for the given sparse attribute.
getZeroAPFloat() const1599 APFloat SparseElementsAttr::getZeroAPFloat() const {
1600 auto eltType = getElementType().cast<FloatType>();
1601 return APFloat(eltType.getFloatSemantics());
1602 }
1603
1604 /// Get a zero APInt for the given sparse attribute.
getZeroAPInt() const1605 APInt SparseElementsAttr::getZeroAPInt() const {
1606 auto eltType = getElementType().cast<IntegerType>();
1607 return APInt::getZero(eltType.getWidth());
1608 }
1609
1610 /// Get a zero attribute for the given attribute type.
getZeroAttr() const1611 Attribute SparseElementsAttr::getZeroAttr() const {
1612 auto eltType = getElementType();
1613
1614 // Handle floating point elements.
1615 if (eltType.isa<FloatType>())
1616 return FloatAttr::get(eltType, 0);
1617
1618 // Handle complex elements.
1619 if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1620 auto eltType = complexTy.getElementType();
1621 Attribute zero;
1622 if (eltType.isa<FloatType>())
1623 zero = FloatAttr::get(eltType, 0);
1624 else // must be integer
1625 zero = IntegerAttr::get(eltType, 0);
1626 return ArrayAttr::get(complexTy.getContext(),
1627 ArrayRef<Attribute>{zero, zero});
1628 }
1629
1630 // Handle string type.
1631 if (getValues().isa<DenseStringElementsAttr>())
1632 return StringAttr::get("", eltType);
1633
1634 // Otherwise, this is an integer.
1635 return IntegerAttr::get(eltType, 0);
1636 }
1637
1638 /// Flatten, and return, all of the sparse indices in this attribute in
1639 /// row-major order.
getFlattenedSparseIndices() const1640 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1641 std::vector<ptrdiff_t> flatSparseIndices;
1642
1643 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1644 // as a 1-D index array.
1645 auto sparseIndices = getIndices();
1646 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1647 if (sparseIndices.isSplat()) {
1648 SmallVector<uint64_t, 8> indices(getType().getRank(),
1649 *sparseIndexValues.begin());
1650 flatSparseIndices.push_back(getFlattenedIndex(indices));
1651 return flatSparseIndices;
1652 }
1653
1654 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1655 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1656 size_t rank = getType().getRank();
1657 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1658 flatSparseIndices.push_back(getFlattenedIndex(
1659 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1660 return flatSparseIndices;
1661 }
1662
1663 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ShapedType type,DenseIntElementsAttr sparseIndices,DenseElementsAttr values)1664 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1665 ShapedType type, DenseIntElementsAttr sparseIndices,
1666 DenseElementsAttr values) {
1667 ShapedType valuesType = values.getType();
1668 if (valuesType.getRank() != 1)
1669 return emitError() << "expected 1-d tensor for sparse element values";
1670
1671 // Verify the indices and values shape.
1672 ShapedType indicesType = sparseIndices.getType();
1673 auto emitShapeError = [&]() {
1674 return emitError() << "expected shape ([" << type.getShape()
1675 << "]); inferred shape of indices literal (["
1676 << indicesType.getShape()
1677 << "]); inferred shape of values literal (["
1678 << valuesType.getShape() << "])";
1679 };
1680 // Verify indices shape.
1681 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1682 if (indicesRank == 2) {
1683 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1684 return emitShapeError();
1685 } else if (indicesRank != 1 || rank != 1) {
1686 return emitShapeError();
1687 }
1688 // Verify the values shape.
1689 int64_t numSparseIndices = indicesType.getDimSize(0);
1690 if (numSparseIndices != valuesType.getDimSize(0))
1691 return emitShapeError();
1692
1693 // Verify that the sparse indices are within the value shape.
1694 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1695 return emitError()
1696 << "sparse index #" << indexNum
1697 << " is not contained within the value shape, with index=[" << index
1698 << "], and type=" << type;
1699 };
1700
1701 // Handle the case where the index values are a splat.
1702 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1703 if (sparseIndices.isSplat()) {
1704 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1705 if (!ElementsAttr::isValidIndex(type, indices))
1706 return emitIndexError(0, indices);
1707 return success();
1708 }
1709
1710 // Otherwise, reinterpret each index as an ArrayRef.
1711 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1712 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1713 rank);
1714 if (!ElementsAttr::isValidIndex(type, index))
1715 return emitIndexError(i, index);
1716 }
1717
1718 return success();
1719 }
1720
1721 //===----------------------------------------------------------------------===//
1722 // TypeAttr
1723 //===----------------------------------------------------------------------===//
1724
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const1725 void TypeAttr::walkImmediateSubElements(
1726 function_ref<void(Attribute)> walkAttrsFn,
1727 function_ref<void(Type)> walkTypesFn) const {
1728 walkTypesFn(getValue());
1729 }
1730
1731 Attribute
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const1732 TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
1733 ArrayRef<Type> replTypes) const {
1734 return get(replTypes[0]);
1735 }
1736