1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute( 57 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 58 std::vector<Attribute> vector = getValue().vec(); 59 for (auto &it : replacements) { 60 vector[it.first] = it.second; 61 } 62 return get(getContext(), vector); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // DictionaryAttr 67 //===----------------------------------------------------------------------===// 68 69 /// Helper function that does either an in place sort or sorts from source array 70 /// into destination. If inPlace then storage is both the source and the 71 /// destination, else value is the source and storage destination. Returns 72 /// whether source was sorted. 73 template <bool inPlace> 74 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 75 SmallVectorImpl<NamedAttribute> &storage) { 76 // Specialize for the common case. 77 switch (value.size()) { 78 case 0: 79 // Zero already sorted. 80 if (!inPlace) 81 storage.clear(); 82 break; 83 case 1: 84 // One already sorted but may need to be copied. 85 if (!inPlace) 86 storage.assign({value[0]}); 87 break; 88 case 2: { 89 bool isSorted = value[0] < value[1]; 90 if (inPlace) { 91 if (!isSorted) 92 std::swap(storage[0], storage[1]); 93 } else if (isSorted) { 94 storage.assign({value[0], value[1]}); 95 } else { 96 storage.assign({value[1], value[0]}); 97 } 98 return !isSorted; 99 } 100 default: 101 if (!inPlace) 102 storage.assign(value.begin(), value.end()); 103 // Check to see they are sorted already. 104 bool isSorted = llvm::is_sorted(value); 105 // If not, do a general sort. 106 if (!isSorted) 107 llvm::array_pod_sort(storage.begin(), storage.end()); 108 return !isSorted; 109 } 110 return false; 111 } 112 113 /// Returns an entry with a duplicate name from the given sorted array of named 114 /// attributes. Returns llvm::None if all elements have unique names. 115 static Optional<NamedAttribute> 116 findDuplicateElement(ArrayRef<NamedAttribute> value) { 117 const Optional<NamedAttribute> none{llvm::None}; 118 if (value.size() < 2) 119 return none; 120 121 if (value.size() == 2) 122 return value[0].getName() == value[1].getName() ? value[0] : none; 123 124 const auto *it = std::adjacent_find(value.begin(), value.end(), 125 [](NamedAttribute l, NamedAttribute r) { 126 return l.getName() == r.getName(); 127 }); 128 return it != value.end() ? *it : none; 129 } 130 131 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 132 SmallVectorImpl<NamedAttribute> &storage) { 133 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 134 assert(!findDuplicateElement(storage) && 135 "DictionaryAttr element names must be unique"); 136 return isSorted; 137 } 138 139 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 140 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 141 assert(!findDuplicateElement(array) && 142 "DictionaryAttr element names must be unique"); 143 return isSorted; 144 } 145 146 Optional<NamedAttribute> 147 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 148 bool isSorted) { 149 if (!isSorted) 150 dictionaryAttrSort</*inPlace=*/true>(array, array); 151 return findDuplicateElement(array); 152 } 153 154 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 155 ArrayRef<NamedAttribute> value) { 156 if (value.empty()) 157 return DictionaryAttr::getEmpty(context); 158 159 // We need to sort the element list to canonicalize it. 160 SmallVector<NamedAttribute, 8> storage; 161 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 162 value = storage; 163 assert(!findDuplicateElement(value) && 164 "DictionaryAttr element names must be unique"); 165 return Base::get(context, value); 166 } 167 /// Construct a dictionary with an array of values that is known to already be 168 /// sorted by name and uniqued. 169 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 170 ArrayRef<NamedAttribute> value) { 171 if (value.empty()) 172 return DictionaryAttr::getEmpty(context); 173 // Ensure that the attribute elements are unique and sorted. 174 assert(llvm::is_sorted( 175 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && 176 "expected attribute values to be sorted"); 177 assert(!findDuplicateElement(value) && 178 "DictionaryAttr element names must be unique"); 179 return Base::get(context, value); 180 } 181 182 /// Return the specified attribute if present, null otherwise. 183 Attribute DictionaryAttr::get(StringRef name) const { 184 auto it = impl::findAttrSorted(begin(), end(), name); 185 return it.second ? it.first->getValue() : Attribute(); 186 } 187 Attribute DictionaryAttr::get(StringAttr name) const { 188 auto it = impl::findAttrSorted(begin(), end(), name); 189 return it.second ? it.first->getValue() : Attribute(); 190 } 191 192 /// Return the specified named attribute if present, None otherwise. 193 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 194 auto it = impl::findAttrSorted(begin(), end(), name); 195 return it.second ? *it.first : Optional<NamedAttribute>(); 196 } 197 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { 198 auto it = impl::findAttrSorted(begin(), end(), name); 199 return it.second ? *it.first : Optional<NamedAttribute>(); 200 } 201 202 /// Return whether the specified attribute is present. 203 bool DictionaryAttr::contains(StringRef name) const { 204 return impl::findAttrSorted(begin(), end(), name).second; 205 } 206 bool DictionaryAttr::contains(StringAttr name) const { 207 return impl::findAttrSorted(begin(), end(), name).second; 208 } 209 210 DictionaryAttr::iterator DictionaryAttr::begin() const { 211 return getValue().begin(); 212 } 213 DictionaryAttr::iterator DictionaryAttr::end() const { 214 return getValue().end(); 215 } 216 size_t DictionaryAttr::size() const { return getValue().size(); } 217 218 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 219 return Base::get(context, ArrayRef<NamedAttribute>()); 220 } 221 222 void DictionaryAttr::walkImmediateSubElements( 223 function_ref<void(Attribute)> walkAttrsFn, 224 function_ref<void(Type)> walkTypesFn) const { 225 for (const NamedAttribute &attr : getValue()) 226 walkAttrsFn(attr.getValue()); 227 } 228 229 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( 230 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 231 std::vector<NamedAttribute> vec = getValue().vec(); 232 for (auto &it : replacements) 233 vec[it.first].setValue(it.second); 234 235 // The above only modifies the mapped value, but not the key, and therefore 236 // not the order of the elements. It remains sorted 237 return getWithSorted(getContext(), vec); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // StringAttr 242 //===----------------------------------------------------------------------===// 243 244 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 245 return Base::get(context, "", NoneType::get(context)); 246 } 247 248 /// Twine support for StringAttr. 249 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 250 // Fast-path empty twine. 251 if (twine.isTriviallyEmpty()) 252 return get(context); 253 SmallVector<char, 32> tempStr; 254 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 255 } 256 257 /// Twine support for StringAttr. 258 StringAttr StringAttr::get(const Twine &twine, Type type) { 259 SmallVector<char, 32> tempStr; 260 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 261 } 262 263 StringRef StringAttr::getValue() const { return getImpl()->value; } 264 265 Dialect *StringAttr::getReferencedDialect() const { 266 return getImpl()->referencedDialect; 267 } 268 269 //===----------------------------------------------------------------------===// 270 // FloatAttr 271 //===----------------------------------------------------------------------===// 272 273 double FloatAttr::getValueAsDouble() const { 274 return getValueAsDouble(getValue()); 275 } 276 double FloatAttr::getValueAsDouble(APFloat value) { 277 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 278 bool losesInfo = false; 279 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 280 &losesInfo); 281 } 282 return value.convertToDouble(); 283 } 284 285 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 286 Type type, APFloat value) { 287 // Verify that the type is correct. 288 if (!type.isa<FloatType>()) 289 return emitError() << "expected floating point type"; 290 291 // Verify that the type semantics match that of the value. 292 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 293 return emitError() 294 << "FloatAttr type doesn't match the type implied by its value"; 295 } 296 return success(); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // SymbolRefAttr 301 //===----------------------------------------------------------------------===// 302 303 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 304 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 305 return get(StringAttr::get(ctx, value), nestedRefs); 306 } 307 308 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 309 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 310 } 311 312 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 313 return get(value, {}).cast<FlatSymbolRefAttr>(); 314 } 315 316 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 317 auto symName = 318 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 319 assert(symName && "value does not have a valid symbol name"); 320 return SymbolRefAttr::get(symName); 321 } 322 323 StringAttr SymbolRefAttr::getLeafReference() const { 324 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 325 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // IntegerAttr 330 //===----------------------------------------------------------------------===// 331 332 int64_t IntegerAttr::getInt() const { 333 assert((getType().isIndex() || getType().isSignlessInteger()) && 334 "must be signless integer"); 335 return getValue().getSExtValue(); 336 } 337 338 int64_t IntegerAttr::getSInt() const { 339 assert(getType().isSignedInteger() && "must be signed integer"); 340 return getValue().getSExtValue(); 341 } 342 343 uint64_t IntegerAttr::getUInt() const { 344 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 345 return getValue().getZExtValue(); 346 } 347 348 /// Return the value as an APSInt which carries the signed from the type of 349 /// the attribute. This traps on signless integers types! 350 APSInt IntegerAttr::getAPSInt() const { 351 assert(!getType().isSignlessInteger() && 352 "Signless integers don't carry a sign for APSInt"); 353 return APSInt(getValue(), getType().isUnsignedInteger()); 354 } 355 356 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 357 Type type, APInt value) { 358 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 359 if (integerType.getWidth() != value.getBitWidth()) 360 return emitError() << "integer type bit width (" << integerType.getWidth() 361 << ") doesn't match value bit width (" 362 << value.getBitWidth() << ")"; 363 return success(); 364 } 365 if (type.isa<IndexType>()) 366 return success(); 367 return emitError() << "expected integer or index type"; 368 } 369 370 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 371 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 372 return attr.cast<BoolAttr>(); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // BoolAttr 377 //===----------------------------------------------------------------------===// 378 379 bool BoolAttr::getValue() const { 380 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 381 return storage->value.getBoolValue(); 382 } 383 384 bool BoolAttr::classof(Attribute attr) { 385 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 386 return intAttr && intAttr.getType().isSignlessInteger(1); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // OpaqueAttr 391 //===----------------------------------------------------------------------===// 392 393 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 394 StringAttr dialect, StringRef attrData, 395 Type type) { 396 if (!Dialect::isValidNamespace(dialect.strref())) 397 return emitError() << "invalid dialect namespace '" << dialect << "'"; 398 399 // Check that the dialect is actually registered. 400 MLIRContext *context = dialect.getContext(); 401 if (!context->allowsUnregisteredDialects() && 402 !context->getLoadedDialect(dialect.strref())) { 403 return emitError() 404 << "#" << dialect << "<\"" << attrData << "\"> : " << type 405 << " attribute created with unregistered dialect. If this is " 406 "intended, please call allowUnregisteredDialects() on the " 407 "MLIRContext, or use -allow-unregistered-dialect with " 408 "the MLIR opt tool used"; 409 } 410 411 return success(); 412 } 413 414 //===----------------------------------------------------------------------===// 415 // DenseElementsAttr Utilities 416 //===----------------------------------------------------------------------===// 417 418 /// Get the bitwidth of a dense element type within the buffer. 419 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 420 static size_t getDenseElementStorageWidth(size_t origWidth) { 421 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 422 } 423 static size_t getDenseElementStorageWidth(Type elementType) { 424 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 425 } 426 427 /// Set a bit to a specific value. 428 static void setBit(char *rawData, size_t bitPos, bool value) { 429 if (value) 430 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 431 else 432 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 433 } 434 435 /// Return the value of the specified bit. 436 static bool getBit(const char *rawData, size_t bitPos) { 437 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 438 } 439 440 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 441 /// BE format. 442 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 443 char *result) { 444 assert(llvm::support::endian::system_endianness() == // NOLINT 445 llvm::support::endianness::big); // NOLINT 446 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 447 448 // Copy the words filled with data. 449 // For example, when `value` has 2 words, the first word is filled with data. 450 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 451 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 452 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 453 numFilledWords, result); 454 // Convert last word of APInt to LE format and store it in char 455 // array(`valueLE`). 456 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 457 size_t lastWordPos = numFilledWords; 458 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 459 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 460 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 461 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 462 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 463 // and store it in `result`. 464 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 465 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 466 valueLE.begin(), result + lastWordPos, 467 (numBytes - lastWordPos) * CHAR_BIT, 1); 468 } 469 470 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 471 /// format. 472 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 473 APInt &result) { 474 assert(llvm::support::endian::system_endianness() == // NOLINT 475 llvm::support::endianness::big); // NOLINT 476 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 477 478 // Copy the data that fills the word of `result` from `inArray`. 479 // For example, when `result` has 2 words, the first word will be filled with 480 // data. So, the first 8 bytes are copied from `inArray` here. 481 // `inArray` (10 bytes, BE): |abcdefgh|ij| 482 // ==> `result` (2 words, BE): |abcdefgh|--------| 483 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 484 std::copy_n( 485 inArray, numFilledWords, 486 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 487 488 // Convert array data which will be last word of `result` to LE format, and 489 // store it in char array(`inArrayLE`). 490 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 491 size_t lastWordPos = numFilledWords; 492 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 493 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 494 inArray + lastWordPos, inArrayLE.begin(), 495 (numBytes - lastWordPos) * CHAR_BIT, 1); 496 497 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 498 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 499 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 500 inArrayLE.begin(), 501 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 502 lastWordPos, 503 APInt::APINT_BITS_PER_WORD, 1); 504 } 505 506 /// Writes value to the bit position `bitPos` in array `rawData`. 507 static void writeBits(char *rawData, size_t bitPos, APInt value) { 508 size_t bitWidth = value.getBitWidth(); 509 510 // If the bitwidth is 1 we just toggle the specific bit. 511 if (bitWidth == 1) 512 return setBit(rawData, bitPos, value.isOneValue()); 513 514 // Otherwise, the bit position is guaranteed to be byte aligned. 515 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 516 if (llvm::support::endian::system_endianness() == 517 llvm::support::endianness::big) { 518 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 519 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 520 // work correctly in BE format. 521 // ex. `value` (2 words including 10 bytes) 522 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 523 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 524 rawData + (bitPos / CHAR_BIT)); 525 } else { 526 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 527 llvm::divideCeil(bitWidth, CHAR_BIT), 528 rawData + (bitPos / CHAR_BIT)); 529 } 530 } 531 532 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 533 /// `rawData`. 534 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 535 // Handle a boolean bit position. 536 if (bitWidth == 1) 537 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 538 539 // Otherwise, the bit position must be 8-bit aligned. 540 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 541 APInt result(bitWidth, 0); 542 if (llvm::support::endian::system_endianness() == 543 llvm::support::endianness::big) { 544 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 545 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 546 // work correctly in BE format. 547 // ex. `result` (2 words including 10 bytes) 548 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 549 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 550 llvm::divideCeil(bitWidth, CHAR_BIT), result); 551 } else { 552 std::copy_n(rawData + (bitPos / CHAR_BIT), 553 llvm::divideCeil(bitWidth, CHAR_BIT), 554 const_cast<char *>( 555 reinterpret_cast<const char *>(result.getRawData()))); 556 } 557 return result; 558 } 559 560 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 561 /// the same element count as 'type'. 562 template <typename Values> 563 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 564 return (values.size() == 1) || 565 (type.getNumElements() == static_cast<int64_t>(values.size())); 566 } 567 568 //===----------------------------------------------------------------------===// 569 // DenseElementsAttr Iterators 570 //===----------------------------------------------------------------------===// 571 572 //===----------------------------------------------------------------------===// 573 // AttributeElementIterator 574 575 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 576 DenseElementsAttr attr, size_t index) 577 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 578 Attribute, Attribute, Attribute>( 579 attr.getAsOpaquePointer(), index) {} 580 581 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 582 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 583 Type eltTy = owner.getElementType(); 584 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 585 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 586 if (eltTy.isa<IndexType>()) 587 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 588 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 589 IntElementIterator intIt(owner, index); 590 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 591 return FloatAttr::get(eltTy, *floatIt); 592 } 593 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 594 auto complexEltTy = complexTy.getElementType(); 595 ComplexIntElementIterator complexIntIt(owner, index); 596 if (complexEltTy.isa<IntegerType>()) { 597 auto value = *complexIntIt; 598 auto real = IntegerAttr::get(complexEltTy, value.real()); 599 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 600 return ArrayAttr::get(complexTy.getContext(), 601 ArrayRef<Attribute>{real, imag}); 602 } 603 604 ComplexFloatElementIterator complexFloatIt( 605 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 606 auto value = *complexFloatIt; 607 auto real = FloatAttr::get(complexEltTy, value.real()); 608 auto imag = FloatAttr::get(complexEltTy, value.imag()); 609 return ArrayAttr::get(complexTy.getContext(), 610 ArrayRef<Attribute>{real, imag}); 611 } 612 if (owner.isa<DenseStringElementsAttr>()) { 613 ArrayRef<StringRef> vals = owner.getRawStringData(); 614 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 615 } 616 llvm_unreachable("unexpected element type"); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // BoolElementIterator 621 622 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 623 DenseElementsAttr attr, size_t dataIndex) 624 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 625 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 626 627 bool DenseElementsAttr::BoolElementIterator::operator*() const { 628 return getBit(getData(), getDataIndex()); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // IntElementIterator 633 634 DenseElementsAttr::IntElementIterator::IntElementIterator( 635 DenseElementsAttr attr, size_t dataIndex) 636 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 637 attr.getRawData().data(), attr.isSplat(), dataIndex), 638 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 639 640 APInt DenseElementsAttr::IntElementIterator::operator*() const { 641 return readBits(getData(), 642 getDataIndex() * getDenseElementStorageWidth(bitWidth), 643 bitWidth); 644 } 645 646 //===----------------------------------------------------------------------===// 647 // ComplexIntElementIterator 648 649 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 650 DenseElementsAttr attr, size_t dataIndex) 651 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 652 std::complex<APInt>, std::complex<APInt>, 653 std::complex<APInt>>( 654 attr.getRawData().data(), attr.isSplat(), dataIndex) { 655 auto complexType = attr.getElementType().cast<ComplexType>(); 656 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 657 } 658 659 std::complex<APInt> 660 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 661 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 662 size_t offset = getDataIndex() * storageWidth * 2; 663 return {readBits(getData(), offset, bitWidth), 664 readBits(getData(), offset + storageWidth, bitWidth)}; 665 } 666 667 //===----------------------------------------------------------------------===// 668 // DenseElementsAttr 669 //===----------------------------------------------------------------------===// 670 671 /// Method for support type inquiry through isa, cast and dyn_cast. 672 bool DenseElementsAttr::classof(Attribute attr) { 673 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 674 } 675 676 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 677 ArrayRef<Attribute> values) { 678 assert(hasSameElementsOrSplat(type, values)); 679 680 // If the element type is not based on int/float/index, assume it is a string 681 // type. 682 auto eltType = type.getElementType(); 683 if (!type.getElementType().isIntOrIndexOrFloat()) { 684 SmallVector<StringRef, 8> stringValues; 685 stringValues.reserve(values.size()); 686 for (Attribute attr : values) { 687 assert(attr.isa<StringAttr>() && 688 "expected string value for non integer/index/float element"); 689 stringValues.push_back(attr.cast<StringAttr>().getValue()); 690 } 691 return get(type, stringValues); 692 } 693 694 // Otherwise, get the raw storage width to use for the allocation. 695 size_t bitWidth = getDenseElementBitWidth(eltType); 696 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 697 698 // Compress the attribute values into a character buffer. 699 SmallVector<char, 8> data( 700 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); 701 APInt intVal; 702 for (unsigned i = 0, e = values.size(); i < e; ++i) { 703 assert(eltType == values[i].getType() && 704 "expected attribute value to have element type"); 705 if (eltType.isa<FloatType>()) 706 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 707 else if (eltType.isa<IntegerType, IndexType>()) 708 intVal = values[i].cast<IntegerAttr>().getValue(); 709 else 710 llvm_unreachable("unexpected element type"); 711 712 assert(intVal.getBitWidth() == bitWidth && 713 "expected value to have same bitwidth as element type"); 714 writeBits(data.data(), i * storageBitWidth, intVal); 715 } 716 717 // Handle the special encoding of splat of bool. 718 if (values.size() == 1 && values[0].getType().isInteger(1)) 719 data[0] = data[0] ? -1 : 0; 720 721 return DenseIntOrFPElementsAttr::getRaw(type, data); 722 } 723 724 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 725 ArrayRef<bool> values) { 726 assert(hasSameElementsOrSplat(type, values)); 727 assert(type.getElementType().isInteger(1)); 728 729 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 730 731 if (!values.empty()) { 732 bool isSplat = true; 733 bool firstValue = values[0]; 734 for (int i = 0, e = values.size(); i != e; ++i) { 735 isSplat &= values[i] == firstValue; 736 setBit(buff.data(), i, values[i]); 737 } 738 739 // Splat of bool is encoded as a byte with all-ones in it. 740 if (isSplat) { 741 buff.resize(1); 742 buff[0] = values[0] ? -1 : 0; 743 } 744 } 745 746 return DenseIntOrFPElementsAttr::getRaw(type, buff); 747 } 748 749 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 750 ArrayRef<StringRef> values) { 751 assert(!type.getElementType().isIntOrFloat()); 752 return DenseStringElementsAttr::get(type, values); 753 } 754 755 /// Constructs a dense integer elements attribute from an array of APInt 756 /// values. Each APInt value is expected to have the same bitwidth as the 757 /// element type of 'type'. 758 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 759 ArrayRef<APInt> values) { 760 assert(type.getElementType().isIntOrIndex()); 761 assert(hasSameElementsOrSplat(type, values)); 762 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 763 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 764 } 765 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 766 ArrayRef<std::complex<APInt>> values) { 767 ComplexType complex = type.getElementType().cast<ComplexType>(); 768 assert(complex.getElementType().isa<IntegerType>()); 769 assert(hasSameElementsOrSplat(type, values)); 770 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 771 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 772 values.size() * 2); 773 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); 774 } 775 776 // Constructs a dense float elements attribute from an array of APFloat 777 // values. Each APFloat value is expected to have the same bitwidth as the 778 // element type of 'type'. 779 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 780 ArrayRef<APFloat> values) { 781 assert(type.getElementType().isa<FloatType>()); 782 assert(hasSameElementsOrSplat(type, values)); 783 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 784 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 785 } 786 DenseElementsAttr 787 DenseElementsAttr::get(ShapedType type, 788 ArrayRef<std::complex<APFloat>> values) { 789 ComplexType complex = type.getElementType().cast<ComplexType>(); 790 assert(complex.getElementType().isa<FloatType>()); 791 assert(hasSameElementsOrSplat(type, values)); 792 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 793 values.size() * 2); 794 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 795 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); 796 } 797 798 /// Construct a dense elements attribute from a raw buffer representing the 799 /// data for this attribute. Users should generally not use this methods as 800 /// the expected buffer format may not be a form the user expects. 801 DenseElementsAttr 802 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { 803 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); 804 } 805 806 /// Returns true if the given buffer is a valid raw buffer for the given type. 807 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 808 ArrayRef<char> rawBuffer, 809 bool &detectedSplat) { 810 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 811 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 812 int64_t numElements = type.getNumElements(); 813 814 // The initializer is always a splat if the result type has a single element. 815 detectedSplat = numElements == 1; 816 817 // Storage width of 1 is special as it is packed by the bit. 818 if (storageWidth == 1) { 819 // Check for a splat, or a buffer equal to the number of elements which 820 // consists of either all 0's or all 1's. 821 if (rawBuffer.size() == 1) { 822 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 823 if (rawByte == 0 || rawByte == 0xff) { 824 detectedSplat = true; 825 return true; 826 } 827 } 828 829 // This is a valid non-splat buffer if it has the right size. 830 return rawBufferWidth == llvm::alignTo<8>(numElements); 831 } 832 833 // All other types are 8-bit aligned, so we can just check the buffer width 834 // to know if only a single initializer element was passed in. 835 if (rawBufferWidth == storageWidth) { 836 detectedSplat = true; 837 return true; 838 } 839 840 // The raw buffer is valid if it has the right size. 841 return rawBufferWidth == storageWidth * numElements; 842 } 843 844 /// Check the information for a C++ data type, check if this type is valid for 845 /// the current attribute. This method is used to verify specific type 846 /// invariants that the templatized 'getValues' method cannot. 847 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 848 bool isSigned) { 849 // Make sure that the data element size is the same as the type element width. 850 if (getDenseElementBitWidth(type) != 851 static_cast<size_t>(dataEltSize * CHAR_BIT)) 852 return false; 853 854 // Check that the element type is either float or integer or index. 855 if (!isInt) 856 return type.isa<FloatType>(); 857 if (type.isIndex()) 858 return true; 859 860 auto intType = type.dyn_cast<IntegerType>(); 861 if (!intType) 862 return false; 863 864 // Make sure signedness semantics is consistent. 865 if (intType.isSignless()) 866 return true; 867 return intType.isSigned() ? isSigned : !isSigned; 868 } 869 870 /// Defaults down the subclass implementation. 871 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 872 ArrayRef<char> data, 873 int64_t dataEltSize, 874 bool isInt, bool isSigned) { 875 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 876 isSigned); 877 } 878 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 879 ArrayRef<char> data, 880 int64_t dataEltSize, 881 bool isInt, 882 bool isSigned) { 883 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 884 isInt, isSigned); 885 } 886 887 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 888 bool isSigned) const { 889 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 890 } 891 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 892 bool isSigned) const { 893 return ::isValidIntOrFloat( 894 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 895 isInt, isSigned); 896 } 897 898 /// Returns true if this attribute corresponds to a splat, i.e. if all element 899 /// values are the same. 900 bool DenseElementsAttr::isSplat() const { 901 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 902 } 903 904 /// Return if the given complex type has an integer element type. 905 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { 906 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 907 } 908 909 auto DenseElementsAttr::getComplexIntValues() const 910 -> iterator_range_impl<ComplexIntElementIterator> { 911 assert(isComplexOfIntType(getElementType()) && 912 "expected complex integral type"); 913 return {getType(), ComplexIntElementIterator(*this, 0), 914 ComplexIntElementIterator(*this, getNumElements())}; 915 } 916 auto DenseElementsAttr::complex_value_begin() const 917 -> ComplexIntElementIterator { 918 assert(isComplexOfIntType(getElementType()) && 919 "expected complex integral type"); 920 return ComplexIntElementIterator(*this, 0); 921 } 922 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 923 assert(isComplexOfIntType(getElementType()) && 924 "expected complex integral type"); 925 return ComplexIntElementIterator(*this, getNumElements()); 926 } 927 928 /// Return the held element values as a range of APFloat. The element type of 929 /// this attribute must be of float type. 930 auto DenseElementsAttr::getFloatValues() const 931 -> iterator_range_impl<FloatElementIterator> { 932 auto elementType = getElementType().cast<FloatType>(); 933 const auto &elementSemantics = elementType.getFloatSemantics(); 934 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 935 FloatElementIterator(elementSemantics, raw_int_end())}; 936 } 937 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 938 auto elementType = getElementType().cast<FloatType>(); 939 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 940 } 941 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 942 auto elementType = getElementType().cast<FloatType>(); 943 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 944 } 945 946 auto DenseElementsAttr::getComplexFloatValues() const 947 -> iterator_range_impl<ComplexFloatElementIterator> { 948 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 949 assert(eltTy.isa<FloatType>() && "expected complex float type"); 950 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 951 return {getType(), 952 {semantics, {*this, 0}}, 953 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 954 } 955 auto DenseElementsAttr::complex_float_value_begin() const 956 -> ComplexFloatElementIterator { 957 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 958 assert(eltTy.isa<FloatType>() && "expected complex float type"); 959 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 960 } 961 auto DenseElementsAttr::complex_float_value_end() const 962 -> ComplexFloatElementIterator { 963 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 964 assert(eltTy.isa<FloatType>() && "expected complex float type"); 965 return {eltTy.cast<FloatType>().getFloatSemantics(), 966 {*this, static_cast<size_t>(getNumElements())}}; 967 } 968 969 /// Return the raw storage data held by this attribute. 970 ArrayRef<char> DenseElementsAttr::getRawData() const { 971 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 972 } 973 974 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 975 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 976 } 977 978 /// Return a new DenseElementsAttr that has the same data as the current 979 /// attribute, but has been reshaped to 'newType'. The new type must have the 980 /// same total number of elements as well as element type. 981 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 982 ShapedType curType = getType(); 983 if (curType == newType) 984 return *this; 985 986 assert(newType.getElementType() == curType.getElementType() && 987 "expected the same element type"); 988 assert(newType.getNumElements() == curType.getNumElements() && 989 "expected the same number of elements"); 990 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 991 } 992 993 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { 994 assert(isSplat() && "expected a splat type"); 995 996 ShapedType curType = getType(); 997 if (curType == newType) 998 return *this; 999 1000 assert(newType.getElementType() == curType.getElementType() && 1001 "expected the same element type"); 1002 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1003 } 1004 1005 /// Return a new DenseElementsAttr that has the same data as the current 1006 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1007 /// type must have the same shape and element types of the same bitwidth as the 1008 /// current type. 1009 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1010 ShapedType curType = getType(); 1011 Type curElType = curType.getElementType(); 1012 if (curElType == newElType) 1013 return *this; 1014 1015 assert(getDenseElementBitWidth(newElType) == 1016 getDenseElementBitWidth(curElType) && 1017 "expected element types with the same bitwidth"); 1018 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1019 getRawData()); 1020 } 1021 1022 DenseElementsAttr 1023 DenseElementsAttr::mapValues(Type newElementType, 1024 function_ref<APInt(const APInt &)> mapping) const { 1025 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1026 } 1027 1028 DenseElementsAttr DenseElementsAttr::mapValues( 1029 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1030 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1031 } 1032 1033 ShapedType DenseElementsAttr::getType() const { 1034 return Attribute::getType().cast<ShapedType>(); 1035 } 1036 1037 Type DenseElementsAttr::getElementType() const { 1038 return getType().getElementType(); 1039 } 1040 1041 int64_t DenseElementsAttr::getNumElements() const { 1042 return getType().getNumElements(); 1043 } 1044 1045 //===----------------------------------------------------------------------===// 1046 // DenseIntOrFPElementsAttr 1047 //===----------------------------------------------------------------------===// 1048 1049 /// Utility method to write a range of APInt values to a buffer. 1050 template <typename APRangeT> 1051 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1052 APRangeT &&values) { 1053 size_t numValues = llvm::size(values); 1054 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); 1055 size_t offset = 0; 1056 for (auto it = values.begin(), e = values.end(); it != e; 1057 ++it, offset += storageWidth) { 1058 assert((*it).getBitWidth() <= storageWidth); 1059 writeBits(data.data(), offset, *it); 1060 } 1061 1062 // Handle the special encoding of splat of a boolean. 1063 if (numValues == 1 && (*values.begin()).getBitWidth() == 1) 1064 data[0] = data[0] ? -1 : 0; 1065 } 1066 1067 /// Constructs a dense elements attribute from an array of raw APFloat values. 1068 /// Each APFloat value is expected to have the same bitwidth as the element 1069 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1071 size_t storageWidth, 1072 ArrayRef<APFloat> values) { 1073 std::vector<char> data; 1074 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1075 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1076 return DenseIntOrFPElementsAttr::getRaw(type, data); 1077 } 1078 1079 /// Constructs a dense elements attribute from an array of raw APInt values. 1080 /// Each APInt value is expected to have the same bitwidth as the element type 1081 /// of 'type'. 1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1083 size_t storageWidth, 1084 ArrayRef<APInt> values) { 1085 std::vector<char> data; 1086 writeAPIntsToBuffer(storageWidth, data, values); 1087 return DenseIntOrFPElementsAttr::getRaw(type, data); 1088 } 1089 1090 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1091 ArrayRef<char> data) { 1092 assert((type.isa<RankedTensorType, VectorType>()) && 1093 "type must be ranked tensor or vector"); 1094 assert(type.hasStaticShape() && "type must have static shape"); 1095 bool isSplat = false; 1096 bool isValid = isValidRawBuffer(type, data, isSplat); 1097 assert(isValid); 1098 (void)isValid; 1099 return Base::get(type.getContext(), type, data, isSplat); 1100 } 1101 1102 /// Overload of the raw 'get' method that asserts that the given type is of 1103 /// complex type. This method is used to verify type invariants that the 1104 /// templatized 'get' method cannot. 1105 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1106 ArrayRef<char> data, 1107 int64_t dataEltSize, 1108 bool isInt, 1109 bool isSigned) { 1110 assert(::isValidIntOrFloat( 1111 type.getElementType().cast<ComplexType>().getElementType(), 1112 dataEltSize / 2, isInt, isSigned)); 1113 1114 int64_t numElements = data.size() / dataEltSize; 1115 (void)numElements; 1116 assert(numElements == 1 || numElements == type.getNumElements()); 1117 return getRaw(type, data); 1118 } 1119 1120 /// Overload of the 'getRaw' method that asserts that the given type is of 1121 /// integer type. This method is used to verify type invariants that the 1122 /// templatized 'get' method cannot. 1123 DenseElementsAttr 1124 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1125 int64_t dataEltSize, bool isInt, 1126 bool isSigned) { 1127 assert( 1128 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1129 1130 int64_t numElements = data.size() / dataEltSize; 1131 assert(numElements == 1 || numElements == type.getNumElements()); 1132 (void)numElements; 1133 return getRaw(type, data); 1134 } 1135 1136 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1137 const char *inRawData, char *outRawData, size_t elementBitWidth, 1138 size_t numElements) { 1139 using llvm::support::ulittle16_t; 1140 using llvm::support::ulittle32_t; 1141 using llvm::support::ulittle64_t; 1142 1143 assert(llvm::support::endian::system_endianness() == // NOLINT 1144 llvm::support::endianness::big); // NOLINT 1145 // NOLINT to avoid warning message about replacing by static_assert() 1146 1147 // Following std::copy_n always converts endianness on BE machine. 1148 switch (elementBitWidth) { 1149 case 16: { 1150 const ulittle16_t *inRawDataPos = 1151 reinterpret_cast<const ulittle16_t *>(inRawData); 1152 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1153 std::copy_n(inRawDataPos, numElements, outDataPos); 1154 break; 1155 } 1156 case 32: { 1157 const ulittle32_t *inRawDataPos = 1158 reinterpret_cast<const ulittle32_t *>(inRawData); 1159 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1160 std::copy_n(inRawDataPos, numElements, outDataPos); 1161 break; 1162 } 1163 case 64: { 1164 const ulittle64_t *inRawDataPos = 1165 reinterpret_cast<const ulittle64_t *>(inRawData); 1166 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1167 std::copy_n(inRawDataPos, numElements, outDataPos); 1168 break; 1169 } 1170 default: { 1171 size_t nBytes = elementBitWidth / CHAR_BIT; 1172 for (size_t i = 0; i < nBytes; i++) 1173 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1174 break; 1175 } 1176 } 1177 } 1178 1179 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1180 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1181 ShapedType type) { 1182 size_t numElements = type.getNumElements(); 1183 Type elementType = type.getElementType(); 1184 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1185 elementType = complexTy.getElementType(); 1186 numElements = numElements * 2; 1187 } 1188 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1189 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1190 inRawData.size() <= outRawData.size()); 1191 if (elementBitWidth <= CHAR_BIT) 1192 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); 1193 else 1194 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1195 elementBitWidth, numElements); 1196 } 1197 1198 //===----------------------------------------------------------------------===// 1199 // DenseFPElementsAttr 1200 //===----------------------------------------------------------------------===// 1201 1202 template <typename Fn, typename Attr> 1203 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1204 Type newElementType, 1205 llvm::SmallVectorImpl<char> &data) { 1206 size_t bitWidth = getDenseElementBitWidth(newElementType); 1207 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1208 1209 ShapedType newArrayType; 1210 if (inType.isa<RankedTensorType>()) 1211 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1212 else if (inType.isa<UnrankedTensorType>()) 1213 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1214 else if (auto vType = inType.dyn_cast<VectorType>()) 1215 newArrayType = VectorType::get(vType.getShape(), newElementType, 1216 vType.getNumScalableDims()); 1217 else 1218 assert(newArrayType && "Unhandled tensor type"); 1219 1220 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1221 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); 1222 1223 // Functor used to process a single element value of the attribute. 1224 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1225 auto newInt = mapping(value); 1226 assert(newInt.getBitWidth() == bitWidth); 1227 writeBits(data.data(), index * storageBitWidth, newInt); 1228 }; 1229 1230 // Check for the splat case. 1231 if (attr.isSplat()) { 1232 processElt(*attr.begin(), /*index=*/0); 1233 return newArrayType; 1234 } 1235 1236 // Otherwise, process all of the element values. 1237 uint64_t elementIdx = 0; 1238 for (auto value : attr) 1239 processElt(value, elementIdx++); 1240 return newArrayType; 1241 } 1242 1243 DenseElementsAttr DenseFPElementsAttr::mapValues( 1244 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1245 llvm::SmallVector<char, 8> elementData; 1246 auto newArrayType = 1247 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1248 1249 return getRaw(newArrayType, elementData); 1250 } 1251 1252 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1253 bool DenseFPElementsAttr::classof(Attribute attr) { 1254 return attr.isa<DenseElementsAttr>() && 1255 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1256 } 1257 1258 //===----------------------------------------------------------------------===// 1259 // DenseIntElementsAttr 1260 //===----------------------------------------------------------------------===// 1261 1262 DenseElementsAttr DenseIntElementsAttr::mapValues( 1263 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1264 llvm::SmallVector<char, 8> elementData; 1265 auto newArrayType = 1266 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1267 return getRaw(newArrayType, elementData); 1268 } 1269 1270 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1271 bool DenseIntElementsAttr::classof(Attribute attr) { 1272 return attr.isa<DenseElementsAttr>() && 1273 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1274 } 1275 1276 //===----------------------------------------------------------------------===// 1277 // OpaqueElementsAttr 1278 //===----------------------------------------------------------------------===// 1279 1280 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1281 Dialect *dialect = getContext()->getLoadedDialect(getDialect()); 1282 if (!dialect) 1283 return true; 1284 auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect); 1285 if (!interface) 1286 return true; 1287 return failed(interface->decode(*this, result)); 1288 } 1289 1290 LogicalResult 1291 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1292 StringAttr dialect, StringRef value, 1293 ShapedType type) { 1294 if (!Dialect::isValidNamespace(dialect.strref())) 1295 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1296 return success(); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // SparseElementsAttr 1301 //===----------------------------------------------------------------------===// 1302 1303 /// Get a zero APFloat for the given sparse attribute. 1304 APFloat SparseElementsAttr::getZeroAPFloat() const { 1305 auto eltType = getElementType().cast<FloatType>(); 1306 return APFloat(eltType.getFloatSemantics()); 1307 } 1308 1309 /// Get a zero APInt for the given sparse attribute. 1310 APInt SparseElementsAttr::getZeroAPInt() const { 1311 auto eltType = getElementType().cast<IntegerType>(); 1312 return APInt::getZero(eltType.getWidth()); 1313 } 1314 1315 /// Get a zero attribute for the given attribute type. 1316 Attribute SparseElementsAttr::getZeroAttr() const { 1317 auto eltType = getElementType(); 1318 1319 // Handle floating point elements. 1320 if (eltType.isa<FloatType>()) 1321 return FloatAttr::get(eltType, 0); 1322 1323 // Handle string type. 1324 if (getValues().isa<DenseStringElementsAttr>()) 1325 return StringAttr::get("", eltType); 1326 1327 // Otherwise, this is an integer. 1328 return IntegerAttr::get(eltType, 0); 1329 } 1330 1331 /// Flatten, and return, all of the sparse indices in this attribute in 1332 /// row-major order. 1333 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1334 std::vector<ptrdiff_t> flatSparseIndices; 1335 1336 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1337 // as a 1-D index array. 1338 auto sparseIndices = getIndices(); 1339 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1340 if (sparseIndices.isSplat()) { 1341 SmallVector<uint64_t, 8> indices(getType().getRank(), 1342 *sparseIndexValues.begin()); 1343 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1344 return flatSparseIndices; 1345 } 1346 1347 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1348 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1349 size_t rank = getType().getRank(); 1350 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1351 flatSparseIndices.push_back(getFlattenedIndex( 1352 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1353 return flatSparseIndices; 1354 } 1355 1356 LogicalResult 1357 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1358 ShapedType type, DenseIntElementsAttr sparseIndices, 1359 DenseElementsAttr values) { 1360 ShapedType valuesType = values.getType(); 1361 if (valuesType.getRank() != 1) 1362 return emitError() << "expected 1-d tensor for sparse element values"; 1363 1364 // Verify the indices and values shape. 1365 ShapedType indicesType = sparseIndices.getType(); 1366 auto emitShapeError = [&]() { 1367 return emitError() << "expected shape ([" << type.getShape() 1368 << "]); inferred shape of indices literal ([" 1369 << indicesType.getShape() 1370 << "]); inferred shape of values literal ([" 1371 << valuesType.getShape() << "])"; 1372 }; 1373 // Verify indices shape. 1374 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1375 if (indicesRank == 2) { 1376 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1377 return emitShapeError(); 1378 } else if (indicesRank != 1 || rank != 1) { 1379 return emitShapeError(); 1380 } 1381 // Verify the values shape. 1382 int64_t numSparseIndices = indicesType.getDimSize(0); 1383 if (numSparseIndices != valuesType.getDimSize(0)) 1384 return emitShapeError(); 1385 1386 // Verify that the sparse indices are within the value shape. 1387 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1388 return emitError() 1389 << "sparse index #" << indexNum 1390 << " is not contained within the value shape, with index=[" << index 1391 << "], and type=" << type; 1392 }; 1393 1394 // Handle the case where the index values are a splat. 1395 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1396 if (sparseIndices.isSplat()) { 1397 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1398 if (!ElementsAttr::isValidIndex(type, indices)) 1399 return emitIndexError(0, indices); 1400 return success(); 1401 } 1402 1403 // Otherwise, reinterpret each index as an ArrayRef. 1404 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1405 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1406 rank); 1407 if (!ElementsAttr::isValidIndex(type, index)) 1408 return emitIndexError(i, index); 1409 } 1410 1411 return success(); 1412 } 1413 1414 //===----------------------------------------------------------------------===// 1415 // TypeAttr 1416 //===----------------------------------------------------------------------===// 1417 1418 void TypeAttr::walkImmediateSubElements( 1419 function_ref<void(Attribute)> walkAttrsFn, 1420 function_ref<void(Type)> walkTypesFn) const { 1421 walkTypesFn(getValue()); 1422 } 1423