1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> 
walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute( 57 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 58 std::vector<Attribute> vector = getValue().vec(); 59 for (auto &it : replacements) { 60 vector[it.first] = it.second; 61 } 62 return get(getContext(), vector); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // DictionaryAttr 67 //===----------------------------------------------------------------------===// 68 69 /// Helper function that does either an in place sort or sorts from source array 70 /// into destination. If inPlace then storage is both the source and the 71 /// destination, else value is the source and storage destination. Returns 72 /// whether source was sorted. 73 template <bool inPlace> 74 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 75 SmallVectorImpl<NamedAttribute> &storage) { 76 // Specialize for the common case. 77 switch (value.size()) { 78 case 0: 79 // Zero already sorted. 80 if (!inPlace) 81 storage.clear(); 82 break; 83 case 1: 84 // One already sorted but may need to be copied. 85 if (!inPlace) 86 storage.assign({value[0]}); 87 break; 88 case 2: { 89 bool isSorted = value[0] < value[1]; 90 if (inPlace) { 91 if (!isSorted) 92 std::swap(storage[0], storage[1]); 93 } else if (isSorted) { 94 storage.assign({value[0], value[1]}); 95 } else { 96 storage.assign({value[1], value[0]}); 97 } 98 return !isSorted; 99 } 100 default: 101 if (!inPlace) 102 storage.assign(value.begin(), value.end()); 103 // Check to see they are sorted already. 104 bool isSorted = llvm::is_sorted(value); 105 // If not, do a general sort. 
106 if (!isSorted) 107 llvm::array_pod_sort(storage.begin(), storage.end()); 108 return !isSorted; 109 } 110 return false; 111 } 112 113 /// Returns an entry with a duplicate name from the given sorted array of named 114 /// attributes. Returns llvm::None if all elements have unique names. 115 static Optional<NamedAttribute> 116 findDuplicateElement(ArrayRef<NamedAttribute> value) { 117 const Optional<NamedAttribute> none{llvm::None}; 118 if (value.size() < 2) 119 return none; 120 121 if (value.size() == 2) 122 return value[0].first == value[1].first ? value[0] : none; 123 124 auto it = std::adjacent_find( 125 value.begin(), value.end(), 126 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 127 return it != value.end() ? *it : none; 128 } 129 130 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 131 SmallVectorImpl<NamedAttribute> &storage) { 132 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 133 assert(!findDuplicateElement(storage) && 134 "DictionaryAttr element names must be unique"); 135 return isSorted; 136 } 137 138 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 139 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 140 assert(!findDuplicateElement(array) && 141 "DictionaryAttr element names must be unique"); 142 return isSorted; 143 } 144 145 Optional<NamedAttribute> 146 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 147 bool isSorted) { 148 if (!isSorted) 149 dictionaryAttrSort</*inPlace=*/true>(array, array); 150 return findDuplicateElement(array); 151 } 152 153 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 154 ArrayRef<NamedAttribute> value) { 155 if (value.empty()) 156 return DictionaryAttr::getEmpty(context); 157 assert(llvm::all_of(value, 158 [](const NamedAttribute &attr) { return attr.second; }) && 159 "value cannot have null entries"); 160 161 // We need to sort the element list to canonicalize it. 
162 SmallVector<NamedAttribute, 8> storage; 163 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 164 value = storage; 165 assert(!findDuplicateElement(value) && 166 "DictionaryAttr element names must be unique"); 167 return Base::get(context, value); 168 } 169 /// Construct a dictionary with an array of values that is known to already be 170 /// sorted by name and uniqued. 171 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 172 ArrayRef<NamedAttribute> value) { 173 if (value.empty()) 174 return DictionaryAttr::getEmpty(context); 175 // Ensure that the attribute elements are unique and sorted. 176 assert(llvm::is_sorted(value, 177 [](NamedAttribute l, NamedAttribute r) { 178 return l.first.strref() < r.first.strref(); 179 }) && 180 "expected attribute values to be sorted"); 181 assert(!findDuplicateElement(value) && 182 "DictionaryAttr element names must be unique"); 183 return Base::get(context, value); 184 } 185 186 /// Return the specified attribute if present, null otherwise. 187 Attribute DictionaryAttr::get(StringRef name) const { 188 auto it = impl::findAttrSorted(begin(), end(), name); 189 return it.second ? it.first->second : Attribute(); 190 } 191 Attribute DictionaryAttr::get(Identifier name) const { 192 auto it = impl::findAttrSorted(begin(), end(), name); 193 return it.second ? it.first->second : Attribute(); 194 } 195 196 /// Return the specified named attribute if present, None otherwise. 197 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 198 auto it = impl::findAttrSorted(begin(), end(), name); 199 return it.second ? *it.first : Optional<NamedAttribute>(); 200 } 201 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 202 auto it = impl::findAttrSorted(begin(), end(), name); 203 return it.second ? *it.first : Optional<NamedAttribute>(); 204 } 205 206 /// Return whether the specified attribute is present. 
207 bool DictionaryAttr::contains(StringRef name) const { 208 return impl::findAttrSorted(begin(), end(), name).second; 209 } 210 bool DictionaryAttr::contains(Identifier name) const { 211 return impl::findAttrSorted(begin(), end(), name).second; 212 } 213 214 DictionaryAttr::iterator DictionaryAttr::begin() const { 215 return getValue().begin(); 216 } 217 DictionaryAttr::iterator DictionaryAttr::end() const { 218 return getValue().end(); 219 } 220 size_t DictionaryAttr::size() const { return getValue().size(); } 221 222 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 223 return Base::get(context, ArrayRef<NamedAttribute>()); 224 } 225 226 void DictionaryAttr::walkImmediateSubElements( 227 function_ref<void(Attribute)> walkAttrsFn, 228 function_ref<void(Type)> walkTypesFn) const { 229 for (Attribute attr : llvm::make_second_range(getValue())) 230 walkAttrsFn(attr); 231 } 232 233 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( 234 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 235 std::vector<NamedAttribute> vec = getValue().vec(); 236 for (auto &it : replacements) { 237 vec[it.first].second = it.second; 238 } 239 // The above only modifies the mapped value, but not the key, and therefore 240 // not the order of the elements. It remains sorted 241 return getWithSorted(getContext(), vec); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // StringAttr 246 //===----------------------------------------------------------------------===// 247 248 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 249 return Base::get(context, "", NoneType::get(context)); 250 } 251 252 /// Twine support for StringAttr. 253 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 254 // Fast-path empty twine. 
255 if (twine.isTriviallyEmpty()) 256 return get(context); 257 SmallVector<char, 32> tempStr; 258 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 259 } 260 261 /// Twine support for StringAttr. 262 StringAttr StringAttr::get(const Twine &twine, Type type) { 263 SmallVector<char, 32> tempStr; 264 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // FloatAttr 269 //===----------------------------------------------------------------------===// 270 271 double FloatAttr::getValueAsDouble() const { 272 return getValueAsDouble(getValue()); 273 } 274 double FloatAttr::getValueAsDouble(APFloat value) { 275 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 276 bool losesInfo = false; 277 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 278 &losesInfo); 279 } 280 return value.convertToDouble(); 281 } 282 283 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 284 Type type, APFloat value) { 285 // Verify that the type is correct. 286 if (!type.isa<FloatType>()) 287 return emitError() << "expected floating point type"; 288 289 // Verify that the type semantics match that of the value. 
290 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 291 return emitError() 292 << "FloatAttr type doesn't match the type implied by its value"; 293 } 294 return success(); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // SymbolRefAttr 299 //===----------------------------------------------------------------------===// 300 301 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 302 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 303 return get(StringAttr::get(ctx, value), nestedRefs); 304 } 305 306 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 307 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 308 } 309 310 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 311 return get(value, {}).cast<FlatSymbolRefAttr>(); 312 } 313 314 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 315 auto symName = 316 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 317 assert(symName && "value does not have a valid symbol name"); 318 return SymbolRefAttr::get(symName); 319 } 320 321 StringAttr SymbolRefAttr::getLeafReference() const { 322 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 323 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // IntegerAttr 328 //===----------------------------------------------------------------------===// 329 330 int64_t IntegerAttr::getInt() const { 331 assert((getType().isIndex() || getType().isSignlessInteger()) && 332 "must be signless integer"); 333 return getValue().getSExtValue(); 334 } 335 336 int64_t IntegerAttr::getSInt() const { 337 assert(getType().isSignedInteger() && "must be signed integer"); 338 return getValue().getSExtValue(); 339 } 340 341 uint64_t IntegerAttr::getUInt() const { 342 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 343 return getValue().getZExtValue(); 344 } 345 346 /// Return the value as an APSInt which carries the signed from the type of 347 /// the attribute. This traps on signless integers types! 348 APSInt IntegerAttr::getAPSInt() const { 349 assert(!getType().isSignlessInteger() && 350 "Signless integers don't carry a sign for APSInt"); 351 return APSInt(getValue(), getType().isUnsignedInteger()); 352 } 353 354 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 355 Type type, APInt value) { 356 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 357 if (integerType.getWidth() != value.getBitWidth()) 358 return emitError() << "integer type bit width (" << integerType.getWidth() 359 << ") doesn't match value bit width (" 360 << value.getBitWidth() << ")"; 361 return success(); 362 } 363 if (type.isa<IndexType>()) 364 return success(); 365 return emitError() << "expected integer or index type"; 366 } 367 368 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 369 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 370 return attr.cast<BoolAttr>(); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // BoolAttr 375 376 bool 
BoolAttr::getValue() const { 377 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 378 return storage->value.getBoolValue(); 379 } 380 381 bool BoolAttr::classof(Attribute attr) { 382 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 383 return intAttr && intAttr.getType().isSignlessInteger(1); 384 } 385 386 //===----------------------------------------------------------------------===// 387 // OpaqueAttr 388 //===----------------------------------------------------------------------===// 389 390 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 391 Identifier dialect, StringRef attrData, 392 Type type) { 393 if (!Dialect::isValidNamespace(dialect.strref())) 394 return emitError() << "invalid dialect namespace '" << dialect << "'"; 395 396 // Check that the dialect is actually registered. 397 MLIRContext *context = dialect.getContext(); 398 if (!context->allowsUnregisteredDialects() && 399 !context->getLoadedDialect(dialect.strref())) { 400 return emitError() 401 << "#" << dialect << "<\"" << attrData << "\"> : " << type 402 << " attribute created with unregistered dialect. If this is " 403 "intended, please call allowUnregisteredDialects() on the " 404 "MLIRContext, or use -allow-unregistered-dialect with " 405 "the MLIR opt tool used"; 406 } 407 408 return success(); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // DenseElementsAttr Utilities 413 //===----------------------------------------------------------------------===// 414 415 /// Get the bitwidth of a dense element type within the buffer. 416 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 417 static size_t getDenseElementStorageWidth(size_t origWidth) { 418 return origWidth == 1 ? 
origWidth : llvm::alignTo<8>(origWidth); 419 } 420 static size_t getDenseElementStorageWidth(Type elementType) { 421 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 422 } 423 424 /// Set a bit to a specific value. 425 static void setBit(char *rawData, size_t bitPos, bool value) { 426 if (value) 427 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 428 else 429 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 430 } 431 432 /// Return the value of the specified bit. 433 static bool getBit(const char *rawData, size_t bitPos) { 434 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 435 } 436 437 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 438 /// BE format. 439 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 440 char *result) { 441 assert(llvm::support::endian::system_endianness() == // NOLINT 442 llvm::support::endianness::big); // NOLINT 443 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 444 445 // Copy the words filled with data. 446 // For example, when `value` has 2 words, the first word is filled with data. 447 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 448 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 449 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 450 numFilledWords, result); 451 // Convert last word of APInt to LE format and store it in char 452 // array(`valueLE`). 453 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 454 size_t lastWordPos = numFilledWords; 455 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 456 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 457 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 458 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 459 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 460 // and store it in `result`. 
461 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 462 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 463 valueLE.begin(), result + lastWordPos, 464 (numBytes - lastWordPos) * CHAR_BIT, 1); 465 } 466 467 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 468 /// format. 469 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 470 APInt &result) { 471 assert(llvm::support::endian::system_endianness() == // NOLINT 472 llvm::support::endianness::big); // NOLINT 473 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 474 475 // Copy the data that fills the word of `result` from `inArray`. 476 // For example, when `result` has 2 words, the first word will be filled with 477 // data. So, the first 8 bytes are copied from `inArray` here. 478 // `inArray` (10 bytes, BE): |abcdefgh|ij| 479 // ==> `result` (2 words, BE): |abcdefgh|--------| 480 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 481 std::copy_n( 482 inArray, numFilledWords, 483 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 484 485 // Convert array data which will be last word of `result` to LE format, and 486 // store it in char array(`inArrayLE`). 487 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 488 size_t lastWordPos = numFilledWords; 489 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 490 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 491 inArray + lastWordPos, inArrayLE.begin(), 492 (numBytes - lastWordPos) * CHAR_BIT, 1); 493 494 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 495 // ex. 
`inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 496 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 497 inArrayLE.begin(), 498 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 499 lastWordPos, 500 APInt::APINT_BITS_PER_WORD, 1); 501 } 502 503 /// Writes value to the bit position `bitPos` in array `rawData`. 504 static void writeBits(char *rawData, size_t bitPos, APInt value) { 505 size_t bitWidth = value.getBitWidth(); 506 507 // If the bitwidth is 1 we just toggle the specific bit. 508 if (bitWidth == 1) 509 return setBit(rawData, bitPos, value.isOneValue()); 510 511 // Otherwise, the bit position is guaranteed to be byte aligned. 512 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 513 if (llvm::support::endian::system_endianness() == 514 llvm::support::endianness::big) { 515 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 516 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 517 // work correctly in BE format. 518 // ex. `value` (2 words including 10 bytes) 519 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 520 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 521 rawData + (bitPos / CHAR_BIT)); 522 } else { 523 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 524 llvm::divideCeil(bitWidth, CHAR_BIT), 525 rawData + (bitPos / CHAR_BIT)); 526 } 527 } 528 529 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 530 /// `rawData`. 531 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 532 // Handle a boolean bit position. 533 if (bitWidth == 1) 534 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 535 536 // Otherwise, the bit position must be 8-bit aligned. 
537 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 538 APInt result(bitWidth, 0); 539 if (llvm::support::endian::system_endianness() == 540 llvm::support::endianness::big) { 541 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 542 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 543 // work correctly in BE format. 544 // ex. `result` (2 words including 10 bytes) 545 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 546 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 547 llvm::divideCeil(bitWidth, CHAR_BIT), result); 548 } else { 549 std::copy_n(rawData + (bitPos / CHAR_BIT), 550 llvm::divideCeil(bitWidth, CHAR_BIT), 551 const_cast<char *>( 552 reinterpret_cast<const char *>(result.getRawData()))); 553 } 554 return result; 555 } 556 557 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 558 /// the same element count as 'type'. 559 template <typename Values> 560 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 561 return (values.size() == 1) || 562 (type.getNumElements() == static_cast<int64_t>(values.size())); 563 } 564 565 //===----------------------------------------------------------------------===// 566 // DenseElementsAttr Iterators 567 //===----------------------------------------------------------------------===// 568 569 //===----------------------------------------------------------------------===// 570 // AttributeElementIterator 571 572 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 573 DenseElementsAttr attr, size_t index) 574 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 575 Attribute, Attribute, Attribute>( 576 attr.getAsOpaquePointer(), index) {} 577 578 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 579 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 580 Type eltTy = owner.getElementType(); 581 if (auto 
intEltTy = eltTy.dyn_cast<IntegerType>()) 582 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 583 if (eltTy.isa<IndexType>()) 584 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 585 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 586 IntElementIterator intIt(owner, index); 587 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 588 return FloatAttr::get(eltTy, *floatIt); 589 } 590 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 591 auto complexEltTy = complexTy.getElementType(); 592 ComplexIntElementIterator complexIntIt(owner, index); 593 if (complexEltTy.isa<IntegerType>()) { 594 auto value = *complexIntIt; 595 auto real = IntegerAttr::get(complexEltTy, value.real()); 596 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 597 return ArrayAttr::get(complexTy.getContext(), 598 ArrayRef<Attribute>{real, imag}); 599 } 600 601 ComplexFloatElementIterator complexFloatIt( 602 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 603 auto value = *complexFloatIt; 604 auto real = FloatAttr::get(complexEltTy, value.real()); 605 auto imag = FloatAttr::get(complexEltTy, value.imag()); 606 return ArrayAttr::get(complexTy.getContext(), 607 ArrayRef<Attribute>{real, imag}); 608 } 609 if (owner.isa<DenseStringElementsAttr>()) { 610 ArrayRef<StringRef> vals = owner.getRawStringData(); 611 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 612 } 613 llvm_unreachable("unexpected element type"); 614 } 615 616 //===----------------------------------------------------------------------===// 617 // BoolElementIterator 618 619 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 620 DenseElementsAttr attr, size_t dataIndex) 621 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 622 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 623 624 bool DenseElementsAttr::BoolElementIterator::operator*() const { 625 return getBit(getData(), getDataIndex()); 626 } 627 628 //===----------------------------------------------------------------------===// 629 // IntElementIterator 630 631 DenseElementsAttr::IntElementIterator::IntElementIterator( 632 DenseElementsAttr attr, size_t dataIndex) 633 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 634 attr.getRawData().data(), attr.isSplat(), dataIndex), 635 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 636 637 APInt DenseElementsAttr::IntElementIterator::operator*() const { 638 return readBits(getData(), 639 getDataIndex() * getDenseElementStorageWidth(bitWidth), 640 bitWidth); 641 } 642 643 //===----------------------------------------------------------------------===// 644 // ComplexIntElementIterator 645 646 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 647 DenseElementsAttr attr, size_t dataIndex) 648 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 649 std::complex<APInt>, std::complex<APInt>, 650 std::complex<APInt>>( 651 attr.getRawData().data(), attr.isSplat(), dataIndex) { 652 auto complexType = attr.getElementType().cast<ComplexType>(); 653 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 654 } 655 656 std::complex<APInt> 657 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 658 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 659 size_t offset = getDataIndex() * 
storageWidth * 2; 660 return {readBits(getData(), offset, bitWidth), 661 readBits(getData(), offset + storageWidth, bitWidth)}; 662 } 663 664 //===----------------------------------------------------------------------===// 665 // FloatElementIterator 666 667 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 668 const llvm::fltSemantics &smt, IntElementIterator it) 669 : llvm::mapped_iterator<IntElementIterator, 670 std::function<APFloat(const APInt &)>>( 671 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 672 673 //===----------------------------------------------------------------------===// 674 // ComplexFloatElementIterator 675 676 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 677 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 678 : llvm::mapped_iterator< 679 ComplexIntElementIterator, 680 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 681 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 682 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 683 }) {} 684 685 //===----------------------------------------------------------------------===// 686 // DenseElementsAttr 687 //===----------------------------------------------------------------------===// 688 689 /// Method for support type inquiry through isa, cast and dyn_cast. 690 bool DenseElementsAttr::classof(Attribute attr) { 691 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 692 } 693 694 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 695 ArrayRef<Attribute> values) { 696 assert(hasSameElementsOrSplat(type, values)); 697 698 // If the element type is not based on int/float/index, assume it is a string 699 // type. 
700 auto eltType = type.getElementType(); 701 if (!type.getElementType().isIntOrIndexOrFloat()) { 702 SmallVector<StringRef, 8> stringValues; 703 stringValues.reserve(values.size()); 704 for (Attribute attr : values) { 705 assert(attr.isa<StringAttr>() && 706 "expected string value for non integer/index/float element"); 707 stringValues.push_back(attr.cast<StringAttr>().getValue()); 708 } 709 return get(type, stringValues); 710 } 711 712 // Otherwise, get the raw storage width to use for the allocation. 713 size_t bitWidth = getDenseElementBitWidth(eltType); 714 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 715 716 // Compress the attribute values into a character buffer. 717 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 718 values.size()); 719 APInt intVal; 720 for (unsigned i = 0, e = values.size(); i < e; ++i) { 721 assert(eltType == values[i].getType() && 722 "expected attribute value to have element type"); 723 if (eltType.isa<FloatType>()) 724 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 725 else if (eltType.isa<IntegerType, IndexType>()) 726 intVal = values[i].cast<IntegerAttr>().getValue(); 727 else 728 llvm_unreachable("unexpected element type"); 729 730 assert(intVal.getBitWidth() == bitWidth && 731 "expected value to have same bitwidth as element type"); 732 writeBits(data.data(), i * storageBitWidth, intVal); 733 } 734 return DenseIntOrFPElementsAttr::getRaw(type, data, 735 /*isSplat=*/(values.size() == 1)); 736 } 737 738 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 739 ArrayRef<bool> values) { 740 assert(hasSameElementsOrSplat(type, values)); 741 assert(type.getElementType().isInteger(1)); 742 743 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 744 for (int i = 0, e = values.size(); i != e; ++i) 745 setBit(buff.data(), i, values[i]); 746 return DenseIntOrFPElementsAttr::getRaw(type, buff, 747 /*isSplat=*/(values.size() == 1)); 748 } 749 750 
/// Constructs a dense string elements attribute. The element type of `type`
/// must not be an integer or float type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  Type elementType = type.getElementType();
  (void)elementType; // Only inspected by the assertion below.
  assert(!elementType.isIntOrFloat());
  return DenseStringElementsAttr::get(type, values);
}

/// Builds a dense integer/index elements attribute out of APInt values. The
/// bitwidth of every APInt is expected to match the element type of 'type',
/// and 'values' is either a single splat value or one value per element.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type elementType = type.getElementType();
  assert(elementType.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));

  const bool isSplat = values.size() == 1;
  const size_t elementStorageBits = getDenseElementStorageWidth(elementType);
  return DenseIntOrFPElementsAttr::getRaw(type, elementStorageBits, values,
                                          isSplat);
}
/// Builds a dense complex-integer elements attribute. Each complex value is
/// viewed as two consecutive APInts (real, then imaginary), and each of those
/// components is stored with half of the complex element's storage width.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  auto complexType = type.getElementType().cast<ComplexType>();
  assert(complexType.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));

  // Reinterpret the array of complex values as a flat array of components.
  const auto *firstComponent = reinterpret_cast<const APInt *>(values.data());
  ArrayRef<APInt> components(firstComponent, values.size() * 2);

  const bool isSplat = values.size() == 1;
  const size_t componentStorageBits =
      getDenseElementStorageWidth(complexType) / 2;
  return DenseIntOrFPElementsAttr::getRaw(type, componentStorageBits,
                                          components, isSplat);
}

// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
782 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 783 ArrayRef<APFloat> values) { 784 assert(type.getElementType().isa<FloatType>()); 785 assert(hasSameElementsOrSplat(type, values)); 786 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 787 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 788 /*isSplat=*/(values.size() == 1)); 789 } 790 DenseElementsAttr 791 DenseElementsAttr::get(ShapedType type, 792 ArrayRef<std::complex<APFloat>> values) { 793 ComplexType complex = type.getElementType().cast<ComplexType>(); 794 assert(complex.getElementType().isa<FloatType>()); 795 assert(hasSameElementsOrSplat(type, values)); 796 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 797 values.size() * 2); 798 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 799 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 800 /*isSplat=*/(values.size() == 1)); 801 } 802 803 /// Construct a dense elements attribute from a raw buffer representing the 804 /// data for this attribute. Users should generally not use this methods as 805 /// the expected buffer format may not be a form the user expects. 806 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 807 ArrayRef<char> rawBuffer, 808 bool isSplatBuffer) { 809 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 810 } 811 812 /// Returns true if the given buffer is a valid raw buffer for the given type. 813 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 814 ArrayRef<char> rawBuffer, 815 bool &detectedSplat) { 816 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 817 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 818 819 // Storage width of 1 is special as it is packed by the bit. 820 if (storageWidth == 1) { 821 // Check for a splat, or a buffer equal to the number of elements which 822 // consists of either all 0's or all 1's. 
823 detectedSplat = false; 824 if (rawBuffer.size() == 1) { 825 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 826 if (rawByte == 0 || rawByte == 0xff) { 827 detectedSplat = true; 828 return true; 829 } 830 } 831 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 832 } 833 // All other types are 8-bit aligned. 834 if ((detectedSplat = rawBufferWidth == storageWidth)) 835 return true; 836 return rawBufferWidth == (storageWidth * type.getNumElements()); 837 } 838 839 /// Check the information for a C++ data type, check if this type is valid for 840 /// the current attribute. This method is used to verify specific type 841 /// invariants that the templatized 'getValues' method cannot. 842 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 843 bool isSigned) { 844 // Make sure that the data element size is the same as the type element width. 845 if (getDenseElementBitWidth(type) != 846 static_cast<size_t>(dataEltSize * CHAR_BIT)) 847 return false; 848 849 // Check that the element type is either float or integer or index. 850 if (!isInt) 851 return type.isa<FloatType>(); 852 if (type.isIndex()) 853 return true; 854 855 auto intType = type.dyn_cast<IntegerType>(); 856 if (!intType) 857 return false; 858 859 // Make sure signedness semantics is consistent. 860 if (intType.isSignless()) 861 return true; 862 return intType.isSigned() ? isSigned : !isSigned; 863 } 864 865 /// Defaults down the subclass implementation. 
866 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 867 ArrayRef<char> data, 868 int64_t dataEltSize, 869 bool isInt, bool isSigned) { 870 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 871 isSigned); 872 } 873 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 874 ArrayRef<char> data, 875 int64_t dataEltSize, 876 bool isInt, 877 bool isSigned) { 878 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 879 isInt, isSigned); 880 } 881 882 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 883 bool isSigned) const { 884 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 885 } 886 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 887 bool isSigned) const { 888 return ::isValidIntOrFloat( 889 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 890 isInt, isSigned); 891 } 892 893 /// Returns true if this attribute corresponds to a splat, i.e. if all element 894 /// values are the same. 895 bool DenseElementsAttr::isSplat() const { 896 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 897 } 898 899 /// Return if the given complex type has an integer element type. 
900 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { 901 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 902 } 903 904 auto DenseElementsAttr::getComplexIntValues() const 905 -> iterator_range_impl<ComplexIntElementIterator> { 906 assert(isComplexOfIntType(getElementType()) && 907 "expected complex integral type"); 908 return {getType(), ComplexIntElementIterator(*this, 0), 909 ComplexIntElementIterator(*this, getNumElements())}; 910 } 911 auto DenseElementsAttr::complex_value_begin() const 912 -> ComplexIntElementIterator { 913 assert(isComplexOfIntType(getElementType()) && 914 "expected complex integral type"); 915 return ComplexIntElementIterator(*this, 0); 916 } 917 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 918 assert(isComplexOfIntType(getElementType()) && 919 "expected complex integral type"); 920 return ComplexIntElementIterator(*this, getNumElements()); 921 } 922 923 /// Return the held element values as a range of APFloat. The element type of 924 /// this attribute must be of float type. 
925 auto DenseElementsAttr::getFloatValues() const 926 -> iterator_range_impl<FloatElementIterator> { 927 auto elementType = getElementType().cast<FloatType>(); 928 const auto &elementSemantics = elementType.getFloatSemantics(); 929 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 930 FloatElementIterator(elementSemantics, raw_int_end())}; 931 } 932 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 933 auto elementType = getElementType().cast<FloatType>(); 934 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 935 } 936 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 937 auto elementType = getElementType().cast<FloatType>(); 938 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 939 } 940 941 auto DenseElementsAttr::getComplexFloatValues() const 942 -> iterator_range_impl<ComplexFloatElementIterator> { 943 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 944 assert(eltTy.isa<FloatType>() && "expected complex float type"); 945 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 946 return {getType(), 947 {semantics, {*this, 0}}, 948 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 949 } 950 auto DenseElementsAttr::complex_float_value_begin() const 951 -> ComplexFloatElementIterator { 952 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 953 assert(eltTy.isa<FloatType>() && "expected complex float type"); 954 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 955 } 956 auto DenseElementsAttr::complex_float_value_end() const 957 -> ComplexFloatElementIterator { 958 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 959 assert(eltTy.isa<FloatType>() && "expected complex float type"); 960 return {eltTy.cast<FloatType>().getFloatSemantics(), 961 {*this, static_cast<size_t>(getNumElements())}}; 962 } 963 964 /// Return the raw storage data held 
by this attribute. 965 ArrayRef<char> DenseElementsAttr::getRawData() const { 966 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 967 } 968 969 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 970 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 971 } 972 973 /// Return a new DenseElementsAttr that has the same data as the current 974 /// attribute, but has been reshaped to 'newType'. The new type must have the 975 /// same total number of elements as well as element type. 976 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 977 ShapedType curType = getType(); 978 if (curType == newType) 979 return *this; 980 981 assert(newType.getElementType() == curType.getElementType() && 982 "expected the same element type"); 983 assert(newType.getNumElements() == curType.getNumElements() && 984 "expected the same number of elements"); 985 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 986 } 987 988 /// Return a new DenseElementsAttr that has the same data as the current 989 /// attribute, but has bitcast elements such that it is now 'newType'. The new 990 /// type must have the same shape and element types of the same bitwidth as the 991 /// current type. 
992 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 993 ShapedType curType = getType(); 994 Type curElType = curType.getElementType(); 995 if (curElType == newElType) 996 return *this; 997 998 assert(getDenseElementBitWidth(newElType) == 999 getDenseElementBitWidth(curElType) && 1000 "expected element types with the same bitwidth"); 1001 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1002 getRawData(), isSplat()); 1003 } 1004 1005 DenseElementsAttr 1006 DenseElementsAttr::mapValues(Type newElementType, 1007 function_ref<APInt(const APInt &)> mapping) const { 1008 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1009 } 1010 1011 DenseElementsAttr DenseElementsAttr::mapValues( 1012 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1013 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1014 } 1015 1016 ShapedType DenseElementsAttr::getType() const { 1017 return Attribute::getType().cast<ShapedType>(); 1018 } 1019 1020 Type DenseElementsAttr::getElementType() const { 1021 return getType().getElementType(); 1022 } 1023 1024 int64_t DenseElementsAttr::getNumElements() const { 1025 return getType().getNumElements(); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // DenseIntOrFPElementsAttr 1030 //===----------------------------------------------------------------------===// 1031 1032 /// Utility method to write a range of APInt values to a buffer. 
1033 template <typename APRangeT> 1034 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1035 APRangeT &&values) { 1036 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1037 size_t offset = 0; 1038 for (auto it = values.begin(), e = values.end(); it != e; 1039 ++it, offset += storageWidth) { 1040 assert((*it).getBitWidth() <= storageWidth); 1041 writeBits(data.data(), offset, *it); 1042 } 1043 } 1044 1045 /// Constructs a dense elements attribute from an array of raw APFloat values. 1046 /// Each APFloat value is expected to have the same bitwidth as the element 1047 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1048 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1049 size_t storageWidth, 1050 ArrayRef<APFloat> values, 1051 bool isSplat) { 1052 std::vector<char> data; 1053 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1054 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1055 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1056 } 1057 1058 /// Constructs a dense elements attribute from an array of raw APInt values. 1059 /// Each APInt value is expected to have the same bitwidth as the element type 1060 /// of 'type'. 
1061 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1062 size_t storageWidth, 1063 ArrayRef<APInt> values, 1064 bool isSplat) { 1065 std::vector<char> data; 1066 writeAPIntsToBuffer(storageWidth, data, values); 1067 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1068 } 1069 1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1071 ArrayRef<char> data, 1072 bool isSplat) { 1073 assert((type.isa<RankedTensorType, VectorType>()) && 1074 "type must be ranked tensor or vector"); 1075 assert(type.hasStaticShape() && "type must have static shape"); 1076 return Base::get(type.getContext(), type, data, isSplat); 1077 } 1078 1079 /// Overload of the raw 'get' method that asserts that the given type is of 1080 /// complex type. This method is used to verify type invariants that the 1081 /// templatized 'get' method cannot. 1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1083 ArrayRef<char> data, 1084 int64_t dataEltSize, 1085 bool isInt, 1086 bool isSigned) { 1087 assert(::isValidIntOrFloat( 1088 type.getElementType().cast<ComplexType>().getElementType(), 1089 dataEltSize / 2, isInt, isSigned)); 1090 1091 int64_t numElements = data.size() / dataEltSize; 1092 assert(numElements == 1 || numElements == type.getNumElements()); 1093 return getRaw(type, data, /*isSplat=*/numElements == 1); 1094 } 1095 1096 /// Overload of the 'getRaw' method that asserts that the given type is of 1097 /// integer type. This method is used to verify type invariants that the 1098 /// templatized 'get' method cannot. 
1099 DenseElementsAttr 1100 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1101 int64_t dataEltSize, bool isInt, 1102 bool isSigned) { 1103 assert( 1104 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1105 1106 int64_t numElements = data.size() / dataEltSize; 1107 assert(numElements == 1 || numElements == type.getNumElements()); 1108 return getRaw(type, data, /*isSplat=*/numElements == 1); 1109 } 1110 1111 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1112 const char *inRawData, char *outRawData, size_t elementBitWidth, 1113 size_t numElements) { 1114 using llvm::support::ulittle16_t; 1115 using llvm::support::ulittle32_t; 1116 using llvm::support::ulittle64_t; 1117 1118 assert(llvm::support::endian::system_endianness() == // NOLINT 1119 llvm::support::endianness::big); // NOLINT 1120 // NOLINT to avoid warning message about replacing by static_assert() 1121 1122 // Following std::copy_n always converts endianness on BE machine. 
1123 switch (elementBitWidth) { 1124 case 16: { 1125 const ulittle16_t *inRawDataPos = 1126 reinterpret_cast<const ulittle16_t *>(inRawData); 1127 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1128 std::copy_n(inRawDataPos, numElements, outDataPos); 1129 break; 1130 } 1131 case 32: { 1132 const ulittle32_t *inRawDataPos = 1133 reinterpret_cast<const ulittle32_t *>(inRawData); 1134 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1135 std::copy_n(inRawDataPos, numElements, outDataPos); 1136 break; 1137 } 1138 case 64: { 1139 const ulittle64_t *inRawDataPos = 1140 reinterpret_cast<const ulittle64_t *>(inRawData); 1141 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1142 std::copy_n(inRawDataPos, numElements, outDataPos); 1143 break; 1144 } 1145 default: { 1146 size_t nBytes = elementBitWidth / CHAR_BIT; 1147 for (size_t i = 0; i < nBytes; i++) 1148 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1149 break; 1150 } 1151 } 1152 } 1153 1154 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1155 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1156 ShapedType type) { 1157 size_t numElements = type.getNumElements(); 1158 Type elementType = type.getElementType(); 1159 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1160 elementType = complexTy.getElementType(); 1161 numElements = numElements * 2; 1162 } 1163 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1164 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1165 inRawData.size() <= outRawData.size()); 1166 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1167 elementBitWidth, numElements); 1168 } 1169 1170 //===----------------------------------------------------------------------===// 1171 // DenseFPElementsAttr 1172 //===----------------------------------------------------------------------===// 1173 1174 template <typename Fn, typename Attr> 
1175 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1176 Type newElementType, 1177 llvm::SmallVectorImpl<char> &data) { 1178 size_t bitWidth = getDenseElementBitWidth(newElementType); 1179 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1180 1181 ShapedType newArrayType; 1182 if (inType.isa<RankedTensorType>()) 1183 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1184 else if (inType.isa<UnrankedTensorType>()) 1185 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1186 else if (inType.isa<VectorType>()) 1187 newArrayType = VectorType::get(inType.getShape(), newElementType); 1188 else 1189 assert(newArrayType && "Unhandled tensor type"); 1190 1191 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1192 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1193 1194 // Functor used to process a single element value of the attribute. 1195 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1196 auto newInt = mapping(value); 1197 assert(newInt.getBitWidth() == bitWidth); 1198 writeBits(data.data(), index * storageBitWidth, newInt); 1199 }; 1200 1201 // Check for the splat case. 1202 if (attr.isSplat()) { 1203 processElt(*attr.begin(), /*index=*/0); 1204 return newArrayType; 1205 } 1206 1207 // Otherwise, process all of the element values. 1208 uint64_t elementIdx = 0; 1209 for (auto value : attr) 1210 processElt(value, elementIdx++); 1211 return newArrayType; 1212 } 1213 1214 DenseElementsAttr DenseFPElementsAttr::mapValues( 1215 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1216 llvm::SmallVector<char, 8> elementData; 1217 auto newArrayType = 1218 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1219 1220 return getRaw(newArrayType, elementData, isSplat()); 1221 } 1222 1223 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1224 bool DenseFPElementsAttr::classof(Attribute attr) { 1225 return attr.isa<DenseElementsAttr>() && 1226 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1227 } 1228 1229 //===----------------------------------------------------------------------===// 1230 // DenseIntElementsAttr 1231 //===----------------------------------------------------------------------===// 1232 1233 DenseElementsAttr DenseIntElementsAttr::mapValues( 1234 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1235 llvm::SmallVector<char, 8> elementData; 1236 auto newArrayType = 1237 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1238 1239 return getRaw(newArrayType, elementData, isSplat()); 1240 } 1241 1242 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1243 bool DenseIntElementsAttr::classof(Attribute attr) { 1244 return attr.isa<DenseElementsAttr>() && 1245 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1246 } 1247 1248 //===----------------------------------------------------------------------===// 1249 // OpaqueElementsAttr 1250 //===----------------------------------------------------------------------===// 1251 1252 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1253 Dialect *dialect = getDialect().getDialect(); 1254 if (!dialect) 1255 return true; 1256 auto *interface = 1257 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1258 if (!interface) 1259 return true; 1260 return failed(interface->decode(*this, result)); 1261 } 1262 1263 LogicalResult 1264 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1265 Identifier dialect, StringRef value, 1266 ShapedType type) { 1267 if (!Dialect::isValidNamespace(dialect.strref())) 1268 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1269 return success(); 1270 } 1271 1272 //===----------------------------------------------------------------------===// 1273 // 
SparseElementsAttr 1274 //===----------------------------------------------------------------------===// 1275 1276 /// Get a zero APFloat for the given sparse attribute. 1277 APFloat SparseElementsAttr::getZeroAPFloat() const { 1278 auto eltType = getElementType().cast<FloatType>(); 1279 return APFloat(eltType.getFloatSemantics()); 1280 } 1281 1282 /// Get a zero APInt for the given sparse attribute. 1283 APInt SparseElementsAttr::getZeroAPInt() const { 1284 auto eltType = getElementType().cast<IntegerType>(); 1285 return APInt::getZero(eltType.getWidth()); 1286 } 1287 1288 /// Get a zero attribute for the given attribute type. 1289 Attribute SparseElementsAttr::getZeroAttr() const { 1290 auto eltType = getElementType(); 1291 1292 // Handle floating point elements. 1293 if (eltType.isa<FloatType>()) 1294 return FloatAttr::get(eltType, 0); 1295 1296 // Handle string type. 1297 if (getValues().isa<DenseStringElementsAttr>()) 1298 return StringAttr::get("", eltType); 1299 1300 // Otherwise, this is an integer. 1301 return IntegerAttr::get(eltType, 0); 1302 } 1303 1304 /// Flatten, and return, all of the sparse indices in this attribute in 1305 /// row-major order. 1306 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1307 std::vector<ptrdiff_t> flatSparseIndices; 1308 1309 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1310 // as a 1-D index array. 1311 auto sparseIndices = getIndices(); 1312 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1313 if (sparseIndices.isSplat()) { 1314 SmallVector<uint64_t, 8> indices(getType().getRank(), 1315 *sparseIndexValues.begin()); 1316 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1317 return flatSparseIndices; 1318 } 1319 1320 // Otherwise, reinterpret each index as an ArrayRef when flattening. 
1321 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1322 size_t rank = getType().getRank(); 1323 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1324 flatSparseIndices.push_back(getFlattenedIndex( 1325 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1326 return flatSparseIndices; 1327 } 1328 1329 LogicalResult 1330 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1331 ShapedType type, DenseIntElementsAttr sparseIndices, 1332 DenseElementsAttr values) { 1333 ShapedType valuesType = values.getType(); 1334 if (valuesType.getRank() != 1) 1335 return emitError() << "expected 1-d tensor for sparse element values"; 1336 1337 // Verify the indices and values shape. 1338 ShapedType indicesType = sparseIndices.getType(); 1339 auto emitShapeError = [&]() { 1340 return emitError() << "expected shape ([" << type.getShape() 1341 << "]); inferred shape of indices literal ([" 1342 << indicesType.getShape() 1343 << "]); inferred shape of values literal ([" 1344 << valuesType.getShape() << "])"; 1345 }; 1346 // Verify indices shape. 1347 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1348 if (indicesRank == 2) { 1349 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1350 return emitShapeError(); 1351 } else if (indicesRank != 1 || rank != 1) { 1352 return emitShapeError(); 1353 } 1354 // Verify the values shape. 1355 int64_t numSparseIndices = indicesType.getDimSize(0); 1356 if (numSparseIndices != valuesType.getDimSize(0)) 1357 return emitShapeError(); 1358 1359 // Verify that the sparse indices are within the value shape. 1360 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1361 return emitError() 1362 << "sparse index #" << indexNum 1363 << " is not contained within the value shape, with index=[" << index 1364 << "], and type=" << type; 1365 }; 1366 1367 // Handle the case where the index values are a splat. 
1368 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1369 if (sparseIndices.isSplat()) { 1370 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1371 if (!ElementsAttr::isValidIndex(type, indices)) 1372 return emitIndexError(0, indices); 1373 return success(); 1374 } 1375 1376 // Otherwise, reinterpret each index as an ArrayRef. 1377 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1378 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1379 rank); 1380 if (!ElementsAttr::isValidIndex(type, index)) 1381 return emitIndexError(i, index); 1382 } 1383 1384 return success(); 1385 } 1386 1387 //===----------------------------------------------------------------------===// 1388 // TypeAttr 1389 //===----------------------------------------------------------------------===// 1390 1391 void TypeAttr::walkImmediateSubElements( 1392 function_ref<void(Attribute)> walkAttrsFn, 1393 function_ref<void(Type)> walkTypesFn) const { 1394 walkTypesFn(getValue()); 1395 } 1396