1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> 
walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // DictionaryAttr 58 //===----------------------------------------------------------------------===// 59 60 /// Helper function that does either an in place sort or sorts from source array 61 /// into destination. If inPlace then storage is both the source and the 62 /// destination, else value is the source and storage destination. Returns 63 /// whether source was sorted. 64 template <bool inPlace> 65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 66 SmallVectorImpl<NamedAttribute> &storage) { 67 // Specialize for the common case. 68 switch (value.size()) { 69 case 0: 70 // Zero already sorted. 71 break; 72 case 1: 73 // One already sorted but may need to be copied. 74 if (!inPlace) 75 storage.assign({value[0]}); 76 break; 77 case 2: { 78 bool isSorted = value[0] < value[1]; 79 if (inPlace) { 80 if (!isSorted) 81 std::swap(storage[0], storage[1]); 82 } else if (isSorted) { 83 storage.assign({value[0], value[1]}); 84 } else { 85 storage.assign({value[1], value[0]}); 86 } 87 return !isSorted; 88 } 89 default: 90 if (!inPlace) 91 storage.assign(value.begin(), value.end()); 92 // Check to see they are sorted already. 93 bool isSorted = llvm::is_sorted(value); 94 // If not, do a general sort. 95 if (!isSorted) 96 llvm::array_pod_sort(storage.begin(), storage.end()); 97 return !isSorted; 98 } 99 return false; 100 } 101 102 /// Returns an entry with a duplicate name from the given sorted array of named 103 /// attributes. Returns llvm::None if all elements have unique names. 104 static Optional<NamedAttribute> 105 findDuplicateElement(ArrayRef<NamedAttribute> value) { 106 const Optional<NamedAttribute> none{llvm::None}; 107 if (value.size() < 2) 108 return none; 109 110 if (value.size() == 2) 111 return value[0].first == value[1].first ? 
value[0] : none; 112 113 auto it = std::adjacent_find( 114 value.begin(), value.end(), 115 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 116 return it != value.end() ? *it : none; 117 } 118 119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 120 SmallVectorImpl<NamedAttribute> &storage) { 121 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 122 assert(!findDuplicateElement(storage) && 123 "DictionaryAttr element names must be unique"); 124 return isSorted; 125 } 126 127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 128 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 129 assert(!findDuplicateElement(array) && 130 "DictionaryAttr element names must be unique"); 131 return isSorted; 132 } 133 134 Optional<NamedAttribute> 135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 136 bool isSorted) { 137 if (!isSorted) 138 dictionaryAttrSort</*inPlace=*/true>(array, array); 139 return findDuplicateElement(array); 140 } 141 142 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 143 ArrayRef<NamedAttribute> value) { 144 if (value.empty()) 145 return DictionaryAttr::getEmpty(context); 146 assert(llvm::all_of(value, 147 [](const NamedAttribute &attr) { return attr.second; }) && 148 "value cannot have null entries"); 149 150 // We need to sort the element list to canonicalize it. 151 SmallVector<NamedAttribute, 8> storage; 152 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 153 value = storage; 154 assert(!findDuplicateElement(value) && 155 "DictionaryAttr element names must be unique"); 156 return Base::get(context, value); 157 } 158 /// Construct a dictionary with an array of values that is known to already be 159 /// sorted by name and uniqued. 
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 161 ArrayRef<NamedAttribute> value) { 162 if (value.empty()) 163 return DictionaryAttr::getEmpty(context); 164 // Ensure that the attribute elements are unique and sorted. 165 assert(llvm::is_sorted(value, 166 [](NamedAttribute l, NamedAttribute r) { 167 return l.first.strref() < r.first.strref(); 168 }) && 169 "expected attribute values to be sorted"); 170 assert(!findDuplicateElement(value) && 171 "DictionaryAttr element names must be unique"); 172 return Base::get(context, value); 173 } 174 175 /// Return the specified attribute if present, null otherwise. 176 Attribute DictionaryAttr::get(StringRef name) const { 177 Optional<NamedAttribute> attr = getNamed(name); 178 return attr ? attr->second : nullptr; 179 } 180 Attribute DictionaryAttr::get(Identifier name) const { 181 Optional<NamedAttribute> attr = getNamed(name); 182 return attr ? attr->second : nullptr; 183 } 184 185 /// Return the specified named attribute if present, None otherwise. 186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 187 ArrayRef<NamedAttribute> values = getValue(); 188 const auto *it = llvm::lower_bound(values, name); 189 return it != values.end() && it->first == name ? 
*it 190 : Optional<NamedAttribute>(); 191 } 192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 193 for (auto elt : getValue()) 194 if (elt.first == name) 195 return elt; 196 return llvm::None; 197 } 198 199 DictionaryAttr::iterator DictionaryAttr::begin() const { 200 return getValue().begin(); 201 } 202 DictionaryAttr::iterator DictionaryAttr::end() const { 203 return getValue().end(); 204 } 205 size_t DictionaryAttr::size() const { return getValue().size(); } 206 207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 208 return Base::get(context, ArrayRef<NamedAttribute>()); 209 } 210 211 void DictionaryAttr::walkImmediateSubElements( 212 function_ref<void(Attribute)> walkAttrsFn, 213 function_ref<void(Type)> walkTypesFn) const { 214 for (Attribute attr : llvm::make_second_range(getValue())) 215 walkAttrsFn(attr); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // StringAttr 220 //===----------------------------------------------------------------------===// 221 222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 223 return Base::get(context, "", NoneType::get(context)); 224 } 225 226 /// Twine support for StringAttr. 227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 228 // Fast-path empty twine. 229 if (twine.isTriviallyEmpty()) 230 return get(context); 231 SmallVector<char, 32> tempStr; 232 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 233 } 234 235 /// Twine support for StringAttr. 
236 StringAttr StringAttr::get(const Twine &twine, Type type) { 237 SmallVector<char, 32> tempStr; 238 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // FloatAttr 243 //===----------------------------------------------------------------------===// 244 245 double FloatAttr::getValueAsDouble() const { 246 return getValueAsDouble(getValue()); 247 } 248 double FloatAttr::getValueAsDouble(APFloat value) { 249 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 250 bool losesInfo = false; 251 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 252 &losesInfo); 253 } 254 return value.convertToDouble(); 255 } 256 257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 258 Type type, APFloat value) { 259 // Verify that the type is correct. 260 if (!type.isa<FloatType>()) 261 return emitError() << "expected floating point type"; 262 263 // Verify that the type semantics match that of the value. 
264 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 265 return emitError() 266 << "FloatAttr type doesn't match the type implied by its value"; 267 } 268 return success(); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // SymbolRefAttr 273 //===----------------------------------------------------------------------===// 274 275 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 276 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 277 return get(StringAttr::get(ctx, value), nestedRefs); 278 } 279 280 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 281 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 282 } 283 284 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 285 return get(value, {}).cast<FlatSymbolRefAttr>(); 286 } 287 288 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 289 auto symName = 290 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 291 assert(symName && "value does not have a valid symbol name"); 292 return SymbolRefAttr::get(symName); 293 } 294 295 StringAttr SymbolRefAttr::getLeafReference() const { 296 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 297 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // IntegerAttr 302 //===----------------------------------------------------------------------===// 303 304 int64_t IntegerAttr::getInt() const { 305 assert((getType().isIndex() || getType().isSignlessInteger()) && 306 "must be signless integer"); 307 return getValue().getSExtValue(); 308 } 309 310 int64_t IntegerAttr::getSInt() const { 311 assert(getType().isSignedInteger() && "must be signed integer"); 312 return getValue().getSExtValue(); 313 } 314 315 uint64_t IntegerAttr::getUInt() const { 316 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 317 return getValue().getZExtValue(); 318 } 319 320 /// Return the value as an APSInt which carries the signed from the type of 321 /// the attribute. This traps on signless integers types! 322 APSInt IntegerAttr::getAPSInt() const { 323 assert(!getType().isSignlessInteger() && 324 "Signless integers don't carry a sign for APSInt"); 325 return APSInt(getValue(), getType().isUnsignedInteger()); 326 } 327 328 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 329 Type type, APInt value) { 330 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 331 if (integerType.getWidth() != value.getBitWidth()) 332 return emitError() << "integer type bit width (" << integerType.getWidth() 333 << ") doesn't match value bit width (" 334 << value.getBitWidth() << ")"; 335 return success(); 336 } 337 if (type.isa<IndexType>()) 338 return success(); 339 return emitError() << "expected integer or index type"; 340 } 341 342 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 343 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 344 return attr.cast<BoolAttr>(); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // BoolAttr 349 350 bool 
BoolAttr::getValue() const { 351 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 352 return storage->value.getBoolValue(); 353 } 354 355 bool BoolAttr::classof(Attribute attr) { 356 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 357 return intAttr && intAttr.getType().isSignlessInteger(1); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // OpaqueAttr 362 //===----------------------------------------------------------------------===// 363 364 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 365 Identifier dialect, StringRef attrData, 366 Type type) { 367 if (!Dialect::isValidNamespace(dialect.strref())) 368 return emitError() << "invalid dialect namespace '" << dialect << "'"; 369 370 // Check that the dialect is actually registered. 371 MLIRContext *context = dialect.getContext(); 372 if (!context->allowsUnregisteredDialects() && 373 !context->getLoadedDialect(dialect.strref())) { 374 return emitError() 375 << "#" << dialect << "<\"" << attrData << "\"> : " << type 376 << " attribute created with unregistered dialect. If this is " 377 "intended, please call allowUnregisteredDialects() on the " 378 "MLIRContext, or use -allow-unregistered-dialect with " 379 "the MLIR opt tool used"; 380 } 381 382 return success(); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // DenseElementsAttr Utilities 387 //===----------------------------------------------------------------------===// 388 389 /// Get the bitwidth of a dense element type within the buffer. 390 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 391 static size_t getDenseElementStorageWidth(size_t origWidth) { 392 return origWidth == 1 ? 
origWidth : llvm::alignTo<8>(origWidth); 393 } 394 static size_t getDenseElementStorageWidth(Type elementType) { 395 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 396 } 397 398 /// Set a bit to a specific value. 399 static void setBit(char *rawData, size_t bitPos, bool value) { 400 if (value) 401 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 402 else 403 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 404 } 405 406 /// Return the value of the specified bit. 407 static bool getBit(const char *rawData, size_t bitPos) { 408 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 409 } 410 411 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 412 /// BE format. 413 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 414 char *result) { 415 assert(llvm::support::endian::system_endianness() == // NOLINT 416 llvm::support::endianness::big); // NOLINT 417 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 418 419 // Copy the words filled with data. 420 // For example, when `value` has 2 words, the first word is filled with data. 421 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 422 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 423 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 424 numFilledWords, result); 425 // Convert last word of APInt to LE format and store it in char 426 // array(`valueLE`). 427 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 428 size_t lastWordPos = numFilledWords; 429 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 430 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 431 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 432 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 433 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 434 // and store it in `result`. 
435 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 436 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 437 valueLE.begin(), result + lastWordPos, 438 (numBytes - lastWordPos) * CHAR_BIT, 1); 439 } 440 441 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 442 /// format. 443 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 444 APInt &result) { 445 assert(llvm::support::endian::system_endianness() == // NOLINT 446 llvm::support::endianness::big); // NOLINT 447 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 448 449 // Copy the data that fills the word of `result` from `inArray`. 450 // For example, when `result` has 2 words, the first word will be filled with 451 // data. So, the first 8 bytes are copied from `inArray` here. 452 // `inArray` (10 bytes, BE): |abcdefgh|ij| 453 // ==> `result` (2 words, BE): |abcdefgh|--------| 454 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 455 std::copy_n( 456 inArray, numFilledWords, 457 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 458 459 // Convert array data which will be last word of `result` to LE format, and 460 // store it in char array(`inArrayLE`). 461 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 462 size_t lastWordPos = numFilledWords; 463 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 464 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 465 inArray + lastWordPos, inArrayLE.begin(), 466 (numBytes - lastWordPos) * CHAR_BIT, 1); 467 468 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 469 // ex. 
`inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 470 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 471 inArrayLE.begin(), 472 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 473 lastWordPos, 474 APInt::APINT_BITS_PER_WORD, 1); 475 } 476 477 /// Writes value to the bit position `bitPos` in array `rawData`. 478 static void writeBits(char *rawData, size_t bitPos, APInt value) { 479 size_t bitWidth = value.getBitWidth(); 480 481 // If the bitwidth is 1 we just toggle the specific bit. 482 if (bitWidth == 1) 483 return setBit(rawData, bitPos, value.isOneValue()); 484 485 // Otherwise, the bit position is guaranteed to be byte aligned. 486 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 487 if (llvm::support::endian::system_endianness() == 488 llvm::support::endianness::big) { 489 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 490 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 491 // work correctly in BE format. 492 // ex. `value` (2 words including 10 bytes) 493 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 494 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 495 rawData + (bitPos / CHAR_BIT)); 496 } else { 497 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 498 llvm::divideCeil(bitWidth, CHAR_BIT), 499 rawData + (bitPos / CHAR_BIT)); 500 } 501 } 502 503 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 504 /// `rawData`. 505 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 506 // Handle a boolean bit position. 507 if (bitWidth == 1) 508 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 509 510 // Otherwise, the bit position must be 8-bit aligned. 
511 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 512 APInt result(bitWidth, 0); 513 if (llvm::support::endian::system_endianness() == 514 llvm::support::endianness::big) { 515 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 516 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 517 // work correctly in BE format. 518 // ex. `result` (2 words including 10 bytes) 519 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 520 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 521 llvm::divideCeil(bitWidth, CHAR_BIT), result); 522 } else { 523 std::copy_n(rawData + (bitPos / CHAR_BIT), 524 llvm::divideCeil(bitWidth, CHAR_BIT), 525 const_cast<char *>( 526 reinterpret_cast<const char *>(result.getRawData()))); 527 } 528 return result; 529 } 530 531 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 532 /// the same element count as 'type'. 533 template <typename Values> 534 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 535 return (values.size() == 1) || 536 (type.getNumElements() == static_cast<int64_t>(values.size())); 537 } 538 539 //===----------------------------------------------------------------------===// 540 // DenseElementsAttr Iterators 541 //===----------------------------------------------------------------------===// 542 543 //===----------------------------------------------------------------------===// 544 // AttributeElementIterator 545 546 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 547 DenseElementsAttr attr, size_t index) 548 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 549 Attribute, Attribute, Attribute>( 550 attr.getAsOpaquePointer(), index) {} 551 552 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 553 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 554 Type eltTy = owner.getElementType(); 555 if (auto 
intEltTy = eltTy.dyn_cast<IntegerType>()) 556 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 557 if (eltTy.isa<IndexType>()) 558 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 559 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 560 IntElementIterator intIt(owner, index); 561 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 562 return FloatAttr::get(eltTy, *floatIt); 563 } 564 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 565 auto complexEltTy = complexTy.getElementType(); 566 ComplexIntElementIterator complexIntIt(owner, index); 567 if (complexEltTy.isa<IntegerType>()) { 568 auto value = *complexIntIt; 569 auto real = IntegerAttr::get(complexEltTy, value.real()); 570 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 571 return ArrayAttr::get(complexTy.getContext(), 572 ArrayRef<Attribute>{real, imag}); 573 } 574 575 ComplexFloatElementIterator complexFloatIt( 576 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 577 auto value = *complexFloatIt; 578 auto real = FloatAttr::get(complexEltTy, value.real()); 579 auto imag = FloatAttr::get(complexEltTy, value.imag()); 580 return ArrayAttr::get(complexTy.getContext(), 581 ArrayRef<Attribute>{real, imag}); 582 } 583 if (owner.isa<DenseStringElementsAttr>()) { 584 ArrayRef<StringRef> vals = owner.getRawStringData(); 585 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 586 } 587 llvm_unreachable("unexpected element type"); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // BoolElementIterator 592 593 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 594 DenseElementsAttr attr, size_t dataIndex) 595 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 596 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 597 598 bool DenseElementsAttr::BoolElementIterator::operator*() const { 599 return getBit(getData(), getDataIndex()); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // IntElementIterator 604 605 DenseElementsAttr::IntElementIterator::IntElementIterator( 606 DenseElementsAttr attr, size_t dataIndex) 607 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 608 attr.getRawData().data(), attr.isSplat(), dataIndex), 609 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 610 611 APInt DenseElementsAttr::IntElementIterator::operator*() const { 612 return readBits(getData(), 613 getDataIndex() * getDenseElementStorageWidth(bitWidth), 614 bitWidth); 615 } 616 617 //===----------------------------------------------------------------------===// 618 // ComplexIntElementIterator 619 620 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 621 DenseElementsAttr attr, size_t dataIndex) 622 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 623 std::complex<APInt>, std::complex<APInt>, 624 std::complex<APInt>>( 625 attr.getRawData().data(), attr.isSplat(), dataIndex) { 626 auto complexType = attr.getElementType().cast<ComplexType>(); 627 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 628 } 629 630 std::complex<APInt> 631 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 632 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 633 size_t offset = getDataIndex() * 
storageWidth * 2; 634 return {readBits(getData(), offset, bitWidth), 635 readBits(getData(), offset + storageWidth, bitWidth)}; 636 } 637 638 //===----------------------------------------------------------------------===// 639 // FloatElementIterator 640 641 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 642 const llvm::fltSemantics &smt, IntElementIterator it) 643 : llvm::mapped_iterator<IntElementIterator, 644 std::function<APFloat(const APInt &)>>( 645 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 646 647 //===----------------------------------------------------------------------===// 648 // ComplexFloatElementIterator 649 650 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 651 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 652 : llvm::mapped_iterator< 653 ComplexIntElementIterator, 654 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 655 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 656 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 657 }) {} 658 659 //===----------------------------------------------------------------------===// 660 // DenseElementsAttr 661 //===----------------------------------------------------------------------===// 662 663 /// Method for support type inquiry through isa, cast and dyn_cast. 664 bool DenseElementsAttr::classof(Attribute attr) { 665 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 666 } 667 668 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 669 ArrayRef<Attribute> values) { 670 assert(hasSameElementsOrSplat(type, values)); 671 672 // If the element type is not based on int/float/index, assume it is a string 673 // type. 
674 auto eltType = type.getElementType(); 675 if (!type.getElementType().isIntOrIndexOrFloat()) { 676 SmallVector<StringRef, 8> stringValues; 677 stringValues.reserve(values.size()); 678 for (Attribute attr : values) { 679 assert(attr.isa<StringAttr>() && 680 "expected string value for non integer/index/float element"); 681 stringValues.push_back(attr.cast<StringAttr>().getValue()); 682 } 683 return get(type, stringValues); 684 } 685 686 // Otherwise, get the raw storage width to use for the allocation. 687 size_t bitWidth = getDenseElementBitWidth(eltType); 688 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 689 690 // Compress the attribute values into a character buffer. 691 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 692 values.size()); 693 APInt intVal; 694 for (unsigned i = 0, e = values.size(); i < e; ++i) { 695 assert(eltType == values[i].getType() && 696 "expected attribute value to have element type"); 697 if (eltType.isa<FloatType>()) 698 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 699 else if (eltType.isa<IntegerType, IndexType>()) 700 intVal = values[i].cast<IntegerAttr>().getValue(); 701 else 702 llvm_unreachable("unexpected element type"); 703 704 assert(intVal.getBitWidth() == bitWidth && 705 "expected value to have same bitwidth as element type"); 706 writeBits(data.data(), i * storageBitWidth, intVal); 707 } 708 return DenseIntOrFPElementsAttr::getRaw(type, data, 709 /*isSplat=*/(values.size() == 1)); 710 } 711 712 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 713 ArrayRef<bool> values) { 714 assert(hasSameElementsOrSplat(type, values)); 715 assert(type.getElementType().isInteger(1)); 716 717 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 718 for (int i = 0, e = values.size(); i != e; ++i) 719 setBit(buff.data(), i, values[i]); 720 return DenseIntOrFPElementsAttr::getRaw(type, buff, 721 /*isSplat=*/(values.size() == 1)); 722 } 723 724 
/// Constructs a dense string elements attribute from an array of strings. The
/// element type of 'type' must not be an integer or float type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  Type eltTy = type.getElementType();
  assert(!eltTy.isIntOrFloat());
  (void)eltTy;
  return DenseStringElementsAttr::get(type, values);
}

/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type eltTy = type.getElementType();
  assert(eltTy.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));
  // A single value denotes a splat.
  const bool singleValue = values.size() == 1;
  const size_t eltStorageBits = getDenseElementStorageWidth(eltTy);
  return DenseIntOrFPElementsAttr::getRaw(type, eltStorageBits, values,
                                          /*isSplat=*/singleValue);
}
/// Constructs a dense complex-integer elements attribute. Each complex value
/// is handed on as two consecutive APInt values (real, then imaginary).
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  ComplexType complexTy = type.getElementType().cast<ComplexType>();
  assert(complexTy.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));
  // Storage width of one component: half the storage width of the complex
  // element.
  const size_t componentStorageBits =
      getDenseElementStorageWidth(complexTy) / 2;
  // View the complex array as a flat array with twice as many APInt entries.
  const APInt *flatBegin = reinterpret_cast<const APInt *>(values.data());
  ArrayRef<APInt> flatValues(flatBegin, values.size() * 2);
  const bool singleValue = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(type, componentStorageBits,
                                          flatValues,
                                          /*isSplat=*/singleValue);
}

// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
756 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 757 ArrayRef<APFloat> values) { 758 assert(type.getElementType().isa<FloatType>()); 759 assert(hasSameElementsOrSplat(type, values)); 760 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 761 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 762 /*isSplat=*/(values.size() == 1)); 763 } 764 DenseElementsAttr 765 DenseElementsAttr::get(ShapedType type, 766 ArrayRef<std::complex<APFloat>> values) { 767 ComplexType complex = type.getElementType().cast<ComplexType>(); 768 assert(complex.getElementType().isa<FloatType>()); 769 assert(hasSameElementsOrSplat(type, values)); 770 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 771 values.size() * 2); 772 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 773 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 774 /*isSplat=*/(values.size() == 1)); 775 } 776 777 /// Construct a dense elements attribute from a raw buffer representing the 778 /// data for this attribute. Users should generally not use this methods as 779 /// the expected buffer format may not be a form the user expects. 780 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 781 ArrayRef<char> rawBuffer, 782 bool isSplatBuffer) { 783 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 784 } 785 786 /// Returns true if the given buffer is a valid raw buffer for the given type. 787 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 788 ArrayRef<char> rawBuffer, 789 bool &detectedSplat) { 790 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 791 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 792 793 // Storage width of 1 is special as it is packed by the bit. 794 if (storageWidth == 1) { 795 // Check for a splat, or a buffer equal to the number of elements which 796 // consists of either all 0's or all 1's. 
797 detectedSplat = false; 798 if (rawBuffer.size() == 1) { 799 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 800 if (rawByte == 0 || rawByte == 0xff) { 801 detectedSplat = true; 802 return true; 803 } 804 } 805 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 806 } 807 // All other types are 8-bit aligned. 808 if ((detectedSplat = rawBufferWidth == storageWidth)) 809 return true; 810 return rawBufferWidth == (storageWidth * type.getNumElements()); 811 } 812 813 /// Check the information for a C++ data type, check if this type is valid for 814 /// the current attribute. This method is used to verify specific type 815 /// invariants that the templatized 'getValues' method cannot. 816 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 817 bool isSigned) { 818 // Make sure that the data element size is the same as the type element width. 819 if (getDenseElementBitWidth(type) != 820 static_cast<size_t>(dataEltSize * CHAR_BIT)) 821 return false; 822 823 // Check that the element type is either float or integer or index. 824 if (!isInt) 825 return type.isa<FloatType>(); 826 if (type.isIndex()) 827 return true; 828 829 auto intType = type.dyn_cast<IntegerType>(); 830 if (!intType) 831 return false; 832 833 // Make sure signedness semantics is consistent. 834 if (intType.isSignless()) 835 return true; 836 return intType.isSigned() ? isSigned : !isSigned; 837 } 838 839 /// Defaults down the subclass implementation. 
840 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 841 ArrayRef<char> data, 842 int64_t dataEltSize, 843 bool isInt, bool isSigned) { 844 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 845 isSigned); 846 } 847 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 848 ArrayRef<char> data, 849 int64_t dataEltSize, 850 bool isInt, 851 bool isSigned) { 852 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 853 isInt, isSigned); 854 } 855 856 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 857 bool isSigned) const { 858 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 859 } 860 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 861 bool isSigned) const { 862 return ::isValidIntOrFloat( 863 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 864 isInt, isSigned); 865 } 866 867 /// Returns true if this attribute corresponds to a splat, i.e. if all element 868 /// values are the same. 869 bool DenseElementsAttr::isSplat() const { 870 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 871 } 872 873 /// Return if the given complex type has an integer element type. 
874 static bool isComplexOfIntType(Type type) { 875 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 876 } 877 878 auto DenseElementsAttr::getComplexIntValues() const 879 -> llvm::iterator_range<ComplexIntElementIterator> { 880 assert(isComplexOfIntType(getElementType()) && 881 "expected complex integral type"); 882 return {ComplexIntElementIterator(*this, 0), 883 ComplexIntElementIterator(*this, getNumElements())}; 884 } 885 auto DenseElementsAttr::complex_value_begin() const 886 -> ComplexIntElementIterator { 887 assert(isComplexOfIntType(getElementType()) && 888 "expected complex integral type"); 889 return ComplexIntElementIterator(*this, 0); 890 } 891 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 892 assert(isComplexOfIntType(getElementType()) && 893 "expected complex integral type"); 894 return ComplexIntElementIterator(*this, getNumElements()); 895 } 896 897 /// Return the held element values as a range of APFloat. The element type of 898 /// this attribute must be of float type. 
899 auto DenseElementsAttr::getFloatValues() const 900 -> llvm::iterator_range<FloatElementIterator> { 901 auto elementType = getElementType().cast<FloatType>(); 902 const auto &elementSemantics = elementType.getFloatSemantics(); 903 return {FloatElementIterator(elementSemantics, raw_int_begin()), 904 FloatElementIterator(elementSemantics, raw_int_end())}; 905 } 906 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 907 auto elementType = getElementType().cast<FloatType>(); 908 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 909 } 910 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 911 auto elementType = getElementType().cast<FloatType>(); 912 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 913 } 914 915 auto DenseElementsAttr::getComplexFloatValues() const 916 -> llvm::iterator_range<ComplexFloatElementIterator> { 917 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 918 assert(eltTy.isa<FloatType>() && "expected complex float type"); 919 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 920 return {{semantics, {*this, 0}}, 921 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 922 } 923 auto DenseElementsAttr::complex_float_value_begin() const 924 -> ComplexFloatElementIterator { 925 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 926 assert(eltTy.isa<FloatType>() && "expected complex float type"); 927 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 928 } 929 auto DenseElementsAttr::complex_float_value_end() const 930 -> ComplexFloatElementIterator { 931 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 932 assert(eltTy.isa<FloatType>() && "expected complex float type"); 933 return {eltTy.cast<FloatType>().getFloatSemantics(), 934 {*this, static_cast<size_t>(getNumElements())}}; 935 } 936 937 /// Return the raw storage data held by this attribute. 
938 ArrayRef<char> DenseElementsAttr::getRawData() const { 939 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 940 } 941 942 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 943 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 944 } 945 946 /// Return a new DenseElementsAttr that has the same data as the current 947 /// attribute, but has been reshaped to 'newType'. The new type must have the 948 /// same total number of elements as well as element type. 949 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 950 ShapedType curType = getType(); 951 if (curType == newType) 952 return *this; 953 954 assert(newType.getElementType() == curType.getElementType() && 955 "expected the same element type"); 956 assert(newType.getNumElements() == curType.getNumElements() && 957 "expected the same number of elements"); 958 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 959 } 960 961 /// Return a new DenseElementsAttr that has the same data as the current 962 /// attribute, but has bitcast elements such that it is now 'newType'. The new 963 /// type must have the same shape and element types of the same bitwidth as the 964 /// current type. 
965 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 966 ShapedType curType = getType(); 967 Type curElType = curType.getElementType(); 968 if (curElType == newElType) 969 return *this; 970 971 assert(getDenseElementBitWidth(newElType) == 972 getDenseElementBitWidth(curElType) && 973 "expected element types with the same bitwidth"); 974 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 975 getRawData(), isSplat()); 976 } 977 978 DenseElementsAttr 979 DenseElementsAttr::mapValues(Type newElementType, 980 function_ref<APInt(const APInt &)> mapping) const { 981 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 982 } 983 984 DenseElementsAttr DenseElementsAttr::mapValues( 985 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 986 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 987 } 988 989 ShapedType DenseElementsAttr::getType() const { 990 return Attribute::getType().cast<ShapedType>(); 991 } 992 993 Type DenseElementsAttr::getElementType() const { 994 return getType().getElementType(); 995 } 996 997 int64_t DenseElementsAttr::getNumElements() const { 998 return getType().getNumElements(); 999 } 1000 1001 //===----------------------------------------------------------------------===// 1002 // DenseIntOrFPElementsAttr 1003 //===----------------------------------------------------------------------===// 1004 1005 /// Utility method to write a range of APInt values to a buffer. 
1006 template <typename APRangeT> 1007 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1008 APRangeT &&values) { 1009 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1010 size_t offset = 0; 1011 for (auto it = values.begin(), e = values.end(); it != e; 1012 ++it, offset += storageWidth) { 1013 assert((*it).getBitWidth() <= storageWidth); 1014 writeBits(data.data(), offset, *it); 1015 } 1016 } 1017 1018 /// Constructs a dense elements attribute from an array of raw APFloat values. 1019 /// Each APFloat value is expected to have the same bitwidth as the element 1020 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1021 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1022 size_t storageWidth, 1023 ArrayRef<APFloat> values, 1024 bool isSplat) { 1025 std::vector<char> data; 1026 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1027 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1028 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1029 } 1030 1031 /// Constructs a dense elements attribute from an array of raw APInt values. 1032 /// Each APInt value is expected to have the same bitwidth as the element type 1033 /// of 'type'. 
1034 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1035 size_t storageWidth, 1036 ArrayRef<APInt> values, 1037 bool isSplat) { 1038 std::vector<char> data; 1039 writeAPIntsToBuffer(storageWidth, data, values); 1040 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1041 } 1042 1043 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1044 ArrayRef<char> data, 1045 bool isSplat) { 1046 assert((type.isa<RankedTensorType, VectorType>()) && 1047 "type must be ranked tensor or vector"); 1048 assert(type.hasStaticShape() && "type must have static shape"); 1049 return Base::get(type.getContext(), type, data, isSplat); 1050 } 1051 1052 /// Overload of the raw 'get' method that asserts that the given type is of 1053 /// complex type. This method is used to verify type invariants that the 1054 /// templatized 'get' method cannot. 1055 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1056 ArrayRef<char> data, 1057 int64_t dataEltSize, 1058 bool isInt, 1059 bool isSigned) { 1060 assert(::isValidIntOrFloat( 1061 type.getElementType().cast<ComplexType>().getElementType(), 1062 dataEltSize / 2, isInt, isSigned)); 1063 1064 int64_t numElements = data.size() / dataEltSize; 1065 assert(numElements == 1 || numElements == type.getNumElements()); 1066 return getRaw(type, data, /*isSplat=*/numElements == 1); 1067 } 1068 1069 /// Overload of the 'getRaw' method that asserts that the given type is of 1070 /// integer type. This method is used to verify type invariants that the 1071 /// templatized 'get' method cannot. 
1072 DenseElementsAttr 1073 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1074 int64_t dataEltSize, bool isInt, 1075 bool isSigned) { 1076 assert( 1077 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1078 1079 int64_t numElements = data.size() / dataEltSize; 1080 assert(numElements == 1 || numElements == type.getNumElements()); 1081 return getRaw(type, data, /*isSplat=*/numElements == 1); 1082 } 1083 1084 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1085 const char *inRawData, char *outRawData, size_t elementBitWidth, 1086 size_t numElements) { 1087 using llvm::support::ulittle16_t; 1088 using llvm::support::ulittle32_t; 1089 using llvm::support::ulittle64_t; 1090 1091 assert(llvm::support::endian::system_endianness() == // NOLINT 1092 llvm::support::endianness::big); // NOLINT 1093 // NOLINT to avoid warning message about replacing by static_assert() 1094 1095 // Following std::copy_n always converts endianness on BE machine. 
1096 switch (elementBitWidth) { 1097 case 16: { 1098 const ulittle16_t *inRawDataPos = 1099 reinterpret_cast<const ulittle16_t *>(inRawData); 1100 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1101 std::copy_n(inRawDataPos, numElements, outDataPos); 1102 break; 1103 } 1104 case 32: { 1105 const ulittle32_t *inRawDataPos = 1106 reinterpret_cast<const ulittle32_t *>(inRawData); 1107 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1108 std::copy_n(inRawDataPos, numElements, outDataPos); 1109 break; 1110 } 1111 case 64: { 1112 const ulittle64_t *inRawDataPos = 1113 reinterpret_cast<const ulittle64_t *>(inRawData); 1114 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1115 std::copy_n(inRawDataPos, numElements, outDataPos); 1116 break; 1117 } 1118 default: { 1119 size_t nBytes = elementBitWidth / CHAR_BIT; 1120 for (size_t i = 0; i < nBytes; i++) 1121 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1122 break; 1123 } 1124 } 1125 } 1126 1127 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1128 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1129 ShapedType type) { 1130 size_t numElements = type.getNumElements(); 1131 Type elementType = type.getElementType(); 1132 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1133 elementType = complexTy.getElementType(); 1134 numElements = numElements * 2; 1135 } 1136 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1137 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1138 inRawData.size() <= outRawData.size()); 1139 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1140 elementBitWidth, numElements); 1141 } 1142 1143 //===----------------------------------------------------------------------===// 1144 // DenseFPElementsAttr 1145 //===----------------------------------------------------------------------===// 1146 1147 template <typename Fn, typename Attr> 
1148 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1149 Type newElementType, 1150 llvm::SmallVectorImpl<char> &data) { 1151 size_t bitWidth = getDenseElementBitWidth(newElementType); 1152 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1153 1154 ShapedType newArrayType; 1155 if (inType.isa<RankedTensorType>()) 1156 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1157 else if (inType.isa<UnrankedTensorType>()) 1158 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1159 else if (inType.isa<VectorType>()) 1160 newArrayType = VectorType::get(inType.getShape(), newElementType); 1161 else 1162 assert(newArrayType && "Unhandled tensor type"); 1163 1164 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1165 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1166 1167 // Functor used to process a single element value of the attribute. 1168 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1169 auto newInt = mapping(value); 1170 assert(newInt.getBitWidth() == bitWidth); 1171 writeBits(data.data(), index * storageBitWidth, newInt); 1172 }; 1173 1174 // Check for the splat case. 1175 if (attr.isSplat()) { 1176 processElt(*attr.begin(), /*index=*/0); 1177 return newArrayType; 1178 } 1179 1180 // Otherwise, process all of the element values. 1181 uint64_t elementIdx = 0; 1182 for (auto value : attr) 1183 processElt(value, elementIdx++); 1184 return newArrayType; 1185 } 1186 1187 DenseElementsAttr DenseFPElementsAttr::mapValues( 1188 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1189 llvm::SmallVector<char, 8> elementData; 1190 auto newArrayType = 1191 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1192 1193 return getRaw(newArrayType, elementData, isSplat()); 1194 } 1195 1196 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1197 bool DenseFPElementsAttr::classof(Attribute attr) { 1198 return attr.isa<DenseElementsAttr>() && 1199 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1200 } 1201 1202 //===----------------------------------------------------------------------===// 1203 // DenseIntElementsAttr 1204 //===----------------------------------------------------------------------===// 1205 1206 DenseElementsAttr DenseIntElementsAttr::mapValues( 1207 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1208 llvm::SmallVector<char, 8> elementData; 1209 auto newArrayType = 1210 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1211 1212 return getRaw(newArrayType, elementData, isSplat()); 1213 } 1214 1215 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1216 bool DenseIntElementsAttr::classof(Attribute attr) { 1217 return attr.isa<DenseElementsAttr>() && 1218 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1219 } 1220 1221 //===----------------------------------------------------------------------===// 1222 // OpaqueElementsAttr 1223 //===----------------------------------------------------------------------===// 1224 1225 /// Return the value at the given index. If index does not refer to a valid 1226 /// element, then a null attribute is returned. 
1227 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1228 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1229 return Attribute(); 1230 } 1231 1232 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1233 Dialect *dialect = getDialect().getDialect(); 1234 if (!dialect) 1235 return true; 1236 auto *interface = 1237 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1238 if (!interface) 1239 return true; 1240 return failed(interface->decode(*this, result)); 1241 } 1242 1243 LogicalResult 1244 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1245 Identifier dialect, StringRef value, 1246 ShapedType type) { 1247 if (!Dialect::isValidNamespace(dialect.strref())) 1248 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1249 return success(); 1250 } 1251 1252 //===----------------------------------------------------------------------===// 1253 // SparseElementsAttr 1254 //===----------------------------------------------------------------------===// 1255 1256 /// Return the value of the element at the given index. 1257 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1258 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1259 auto type = getType(); 1260 1261 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1262 // as a 1-D index array. 1263 auto sparseIndices = getIndices(); 1264 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1265 1266 // Check to see if the indices are a splat. 1267 if (sparseIndices.isSplat()) { 1268 // If the index is also not a splat of the index value, we know that the 1269 // value is zero. 1270 auto splatIndex = *sparseIndexValues.begin(); 1271 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1272 return getZeroAttr(); 1273 1274 // If the indices are a splat, we also expect the values to be a splat. 
1275 assert(getValues().isSplat() && "expected splat values"); 1276 return getValues().getSplatValue(); 1277 } 1278 1279 // Build a mapping between known indices and the offset of the stored element. 1280 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1281 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1282 size_t rank = type.getRank(); 1283 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1284 mappedIndices.try_emplace( 1285 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1286 1287 // Look for the provided index key within the mapped indices. If the provided 1288 // index is not found, then return a zero attribute. 1289 auto it = mappedIndices.find(index); 1290 if (it == mappedIndices.end()) 1291 return getZeroAttr(); 1292 1293 // Otherwise, return the held sparse value element. 1294 return getValues().getValue(it->second); 1295 } 1296 1297 /// Get a zero APFloat for the given sparse attribute. 1298 APFloat SparseElementsAttr::getZeroAPFloat() const { 1299 auto eltType = getElementType().cast<FloatType>(); 1300 return APFloat(eltType.getFloatSemantics()); 1301 } 1302 1303 /// Get a zero APInt for the given sparse attribute. 1304 APInt SparseElementsAttr::getZeroAPInt() const { 1305 auto eltType = getElementType().cast<IntegerType>(); 1306 return APInt::getZero(eltType.getWidth()); 1307 } 1308 1309 /// Get a zero attribute for the given attribute type. 1310 Attribute SparseElementsAttr::getZeroAttr() const { 1311 auto eltType = getElementType(); 1312 1313 // Handle floating point elements. 1314 if (eltType.isa<FloatType>()) 1315 return FloatAttr::get(eltType, 0); 1316 1317 // Otherwise, this is an integer. 1318 // TODO: Handle StringAttr here. 1319 return IntegerAttr::get(eltType, 0); 1320 } 1321 1322 /// Flatten, and return, all of the sparse indices in this attribute in 1323 /// row-major order. 
1324 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1325 std::vector<ptrdiff_t> flatSparseIndices; 1326 1327 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1328 // as a 1-D index array. 1329 auto sparseIndices = getIndices(); 1330 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1331 if (sparseIndices.isSplat()) { 1332 SmallVector<uint64_t, 8> indices(getType().getRank(), 1333 *sparseIndexValues.begin()); 1334 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1335 return flatSparseIndices; 1336 } 1337 1338 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1339 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1340 size_t rank = getType().getRank(); 1341 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1342 flatSparseIndices.push_back(getFlattenedIndex( 1343 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1344 return flatSparseIndices; 1345 } 1346 1347 LogicalResult 1348 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1349 ShapedType type, DenseIntElementsAttr sparseIndices, 1350 DenseElementsAttr values) { 1351 ShapedType valuesType = values.getType(); 1352 if (valuesType.getRank() != 1) 1353 return emitError() << "expected 1-d tensor for sparse element values"; 1354 1355 // Verify the indices and values shape. 1356 ShapedType indicesType = sparseIndices.getType(); 1357 auto emitShapeError = [&]() { 1358 return emitError() << "expected shape ([" << type.getShape() 1359 << "]); inferred shape of indices literal ([" 1360 << indicesType.getShape() 1361 << "]); inferred shape of values literal ([" 1362 << valuesType.getShape() << "])"; 1363 }; 1364 // Verify indices shape. 
1365 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1366 if (indicesRank == 2) { 1367 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1368 return emitShapeError(); 1369 } else if (indicesRank != 1 || rank != 1) { 1370 return emitShapeError(); 1371 } 1372 // Verify the values shape. 1373 int64_t numSparseIndices = indicesType.getDimSize(0); 1374 if (numSparseIndices != valuesType.getDimSize(0)) 1375 return emitShapeError(); 1376 1377 // Verify that the sparse indices are within the value shape. 1378 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1379 return emitError() 1380 << "sparse index #" << indexNum 1381 << " is not contained within the value shape, with index=[" << index 1382 << "], and type=" << type; 1383 }; 1384 1385 // Handle the case where the index values are a splat. 1386 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1387 if (sparseIndices.isSplat()) { 1388 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1389 if (!ElementsAttr::isValidIndex(type, indices)) 1390 return emitIndexError(0, indices); 1391 return success(); 1392 } 1393 1394 // Otherwise, reinterpret each index as an ArrayRef. 1395 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1396 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1397 rank); 1398 if (!ElementsAttr::isValidIndex(type, index)) 1399 return emitIndexError(i, index); 1400 } 1401 1402 return success(); 1403 } 1404 1405 //===----------------------------------------------------------------------===// 1406 // TypeAttr 1407 //===----------------------------------------------------------------------===// 1408 1409 void TypeAttr::walkImmediateSubElements( 1410 function_ref<void(Attribute)> walkAttrsFn, 1411 function_ref<void(Type)> walkTypesFn) const { 1412 walkTypesFn(getValue()); 1413 } 1414