1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> 
walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // DictionaryAttr 58 //===----------------------------------------------------------------------===// 59 60 /// Helper function that does either an in place sort or sorts from source array 61 /// into destination. If inPlace then storage is both the source and the 62 /// destination, else value is the source and storage destination. Returns 63 /// whether source was sorted. 64 template <bool inPlace> 65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 66 SmallVectorImpl<NamedAttribute> &storage) { 67 // Specialize for the common case. 68 switch (value.size()) { 69 case 0: 70 // Zero already sorted. 71 if (!inPlace) 72 storage.clear(); 73 break; 74 case 1: 75 // One already sorted but may need to be copied. 76 if (!inPlace) 77 storage.assign({value[0]}); 78 break; 79 case 2: { 80 bool isSorted = value[0] < value[1]; 81 if (inPlace) { 82 if (!isSorted) 83 std::swap(storage[0], storage[1]); 84 } else if (isSorted) { 85 storage.assign({value[0], value[1]}); 86 } else { 87 storage.assign({value[1], value[0]}); 88 } 89 return !isSorted; 90 } 91 default: 92 if (!inPlace) 93 storage.assign(value.begin(), value.end()); 94 // Check to see they are sorted already. 95 bool isSorted = llvm::is_sorted(value); 96 // If not, do a general sort. 97 if (!isSorted) 98 llvm::array_pod_sort(storage.begin(), storage.end()); 99 return !isSorted; 100 } 101 return false; 102 } 103 104 /// Returns an entry with a duplicate name from the given sorted array of named 105 /// attributes. Returns llvm::None if all elements have unique names. 
106 static Optional<NamedAttribute> 107 findDuplicateElement(ArrayRef<NamedAttribute> value) { 108 const Optional<NamedAttribute> none{llvm::None}; 109 if (value.size() < 2) 110 return none; 111 112 if (value.size() == 2) 113 return value[0].first == value[1].first ? value[0] : none; 114 115 auto it = std::adjacent_find( 116 value.begin(), value.end(), 117 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 118 return it != value.end() ? *it : none; 119 } 120 121 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 122 SmallVectorImpl<NamedAttribute> &storage) { 123 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 124 assert(!findDuplicateElement(storage) && 125 "DictionaryAttr element names must be unique"); 126 return isSorted; 127 } 128 129 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 130 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 131 assert(!findDuplicateElement(array) && 132 "DictionaryAttr element names must be unique"); 133 return isSorted; 134 } 135 136 Optional<NamedAttribute> 137 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 138 bool isSorted) { 139 if (!isSorted) 140 dictionaryAttrSort</*inPlace=*/true>(array, array); 141 return findDuplicateElement(array); 142 } 143 144 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 145 ArrayRef<NamedAttribute> value) { 146 if (value.empty()) 147 return DictionaryAttr::getEmpty(context); 148 assert(llvm::all_of(value, 149 [](const NamedAttribute &attr) { return attr.second; }) && 150 "value cannot have null entries"); 151 152 // We need to sort the element list to canonicalize it. 
153 SmallVector<NamedAttribute, 8> storage; 154 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 155 value = storage; 156 assert(!findDuplicateElement(value) && 157 "DictionaryAttr element names must be unique"); 158 return Base::get(context, value); 159 } 160 /// Construct a dictionary with an array of values that is known to already be 161 /// sorted by name and uniqued. 162 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 163 ArrayRef<NamedAttribute> value) { 164 if (value.empty()) 165 return DictionaryAttr::getEmpty(context); 166 // Ensure that the attribute elements are unique and sorted. 167 assert(llvm::is_sorted(value, 168 [](NamedAttribute l, NamedAttribute r) { 169 return l.first.strref() < r.first.strref(); 170 }) && 171 "expected attribute values to be sorted"); 172 assert(!findDuplicateElement(value) && 173 "DictionaryAttr element names must be unique"); 174 return Base::get(context, value); 175 } 176 177 /// Return the specified attribute if present, null otherwise. 178 Attribute DictionaryAttr::get(StringRef name) const { 179 Optional<NamedAttribute> attr = getNamed(name); 180 return attr ? attr->second : nullptr; 181 } 182 Attribute DictionaryAttr::get(Identifier name) const { 183 Optional<NamedAttribute> attr = getNamed(name); 184 return attr ? attr->second : nullptr; 185 } 186 187 /// Return the specified named attribute if present, None otherwise. 188 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 189 ArrayRef<NamedAttribute> values = getValue(); 190 const auto *it = llvm::lower_bound(values, name); 191 return it != values.end() && it->first == name ? 
*it 192 : Optional<NamedAttribute>(); 193 } 194 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 195 for (auto elt : getValue()) 196 if (elt.first == name) 197 return elt; 198 return llvm::None; 199 } 200 201 DictionaryAttr::iterator DictionaryAttr::begin() const { 202 return getValue().begin(); 203 } 204 DictionaryAttr::iterator DictionaryAttr::end() const { 205 return getValue().end(); 206 } 207 size_t DictionaryAttr::size() const { return getValue().size(); } 208 209 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 210 return Base::get(context, ArrayRef<NamedAttribute>()); 211 } 212 213 void DictionaryAttr::walkImmediateSubElements( 214 function_ref<void(Attribute)> walkAttrsFn, 215 function_ref<void(Type)> walkTypesFn) const { 216 for (Attribute attr : llvm::make_second_range(getValue())) 217 walkAttrsFn(attr); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // StringAttr 222 //===----------------------------------------------------------------------===// 223 224 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 225 return Base::get(context, "", NoneType::get(context)); 226 } 227 228 /// Twine support for StringAttr. 229 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 230 // Fast-path empty twine. 231 if (twine.isTriviallyEmpty()) 232 return get(context); 233 SmallVector<char, 32> tempStr; 234 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 235 } 236 237 /// Twine support for StringAttr. 
238 StringAttr StringAttr::get(const Twine &twine, Type type) { 239 SmallVector<char, 32> tempStr; 240 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // FloatAttr 245 //===----------------------------------------------------------------------===// 246 247 double FloatAttr::getValueAsDouble() const { 248 return getValueAsDouble(getValue()); 249 } 250 double FloatAttr::getValueAsDouble(APFloat value) { 251 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 252 bool losesInfo = false; 253 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 254 &losesInfo); 255 } 256 return value.convertToDouble(); 257 } 258 259 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 260 Type type, APFloat value) { 261 // Verify that the type is correct. 262 if (!type.isa<FloatType>()) 263 return emitError() << "expected floating point type"; 264 265 // Verify that the type semantics match that of the value. 
266 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 267 return emitError() 268 << "FloatAttr type doesn't match the type implied by its value"; 269 } 270 return success(); 271 } 272 273 //===----------------------------------------------------------------------===// 274 // SymbolRefAttr 275 //===----------------------------------------------------------------------===// 276 277 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 278 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 279 return get(StringAttr::get(ctx, value), nestedRefs); 280 } 281 282 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 283 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 284 } 285 286 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 287 return get(value, {}).cast<FlatSymbolRefAttr>(); 288 } 289 290 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 291 auto symName = 292 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 293 assert(symName && "value does not have a valid symbol name"); 294 return SymbolRefAttr::get(symName); 295 } 296 297 StringAttr SymbolRefAttr::getLeafReference() const { 298 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 299 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // IntegerAttr 304 //===----------------------------------------------------------------------===// 305 306 int64_t IntegerAttr::getInt() const { 307 assert((getType().isIndex() || getType().isSignlessInteger()) && 308 "must be signless integer"); 309 return getValue().getSExtValue(); 310 } 311 312 int64_t IntegerAttr::getSInt() const { 313 assert(getType().isSignedInteger() && "must be signed integer"); 314 return getValue().getSExtValue(); 315 } 316 317 uint64_t IntegerAttr::getUInt() const { 318 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 319 return getValue().getZExtValue(); 320 } 321 322 /// Return the value as an APSInt which carries the signed from the type of 323 /// the attribute. This traps on signless integers types! 324 APSInt IntegerAttr::getAPSInt() const { 325 assert(!getType().isSignlessInteger() && 326 "Signless integers don't carry a sign for APSInt"); 327 return APSInt(getValue(), getType().isUnsignedInteger()); 328 } 329 330 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 331 Type type, APInt value) { 332 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 333 if (integerType.getWidth() != value.getBitWidth()) 334 return emitError() << "integer type bit width (" << integerType.getWidth() 335 << ") doesn't match value bit width (" 336 << value.getBitWidth() << ")"; 337 return success(); 338 } 339 if (type.isa<IndexType>()) 340 return success(); 341 return emitError() << "expected integer or index type"; 342 } 343 344 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 345 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 346 return attr.cast<BoolAttr>(); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // BoolAttr 351 352 bool 
BoolAttr::getValue() const { 353 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 354 return storage->value.getBoolValue(); 355 } 356 357 bool BoolAttr::classof(Attribute attr) { 358 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 359 return intAttr && intAttr.getType().isSignlessInteger(1); 360 } 361 362 //===----------------------------------------------------------------------===// 363 // OpaqueAttr 364 //===----------------------------------------------------------------------===// 365 366 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 367 Identifier dialect, StringRef attrData, 368 Type type) { 369 if (!Dialect::isValidNamespace(dialect.strref())) 370 return emitError() << "invalid dialect namespace '" << dialect << "'"; 371 372 // Check that the dialect is actually registered. 373 MLIRContext *context = dialect.getContext(); 374 if (!context->allowsUnregisteredDialects() && 375 !context->getLoadedDialect(dialect.strref())) { 376 return emitError() 377 << "#" << dialect << "<\"" << attrData << "\"> : " << type 378 << " attribute created with unregistered dialect. If this is " 379 "intended, please call allowUnregisteredDialects() on the " 380 "MLIRContext, or use -allow-unregistered-dialect with " 381 "the MLIR opt tool used"; 382 } 383 384 return success(); 385 } 386 387 //===----------------------------------------------------------------------===// 388 // DenseElementsAttr Utilities 389 //===----------------------------------------------------------------------===// 390 391 /// Get the bitwidth of a dense element type within the buffer. 392 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 393 static size_t getDenseElementStorageWidth(size_t origWidth) { 394 return origWidth == 1 ? 
origWidth : llvm::alignTo<8>(origWidth); 395 } 396 static size_t getDenseElementStorageWidth(Type elementType) { 397 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 398 } 399 400 /// Set a bit to a specific value. 401 static void setBit(char *rawData, size_t bitPos, bool value) { 402 if (value) 403 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 404 else 405 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 406 } 407 408 /// Return the value of the specified bit. 409 static bool getBit(const char *rawData, size_t bitPos) { 410 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 411 } 412 413 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 414 /// BE format. 415 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 416 char *result) { 417 assert(llvm::support::endian::system_endianness() == // NOLINT 418 llvm::support::endianness::big); // NOLINT 419 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 420 421 // Copy the words filled with data. 422 // For example, when `value` has 2 words, the first word is filled with data. 423 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 424 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 425 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 426 numFilledWords, result); 427 // Convert last word of APInt to LE format and store it in char 428 // array(`valueLE`). 429 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 430 size_t lastWordPos = numFilledWords; 431 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 432 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 433 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 434 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 435 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 436 // and store it in `result`. 
437 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 438 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 439 valueLE.begin(), result + lastWordPos, 440 (numBytes - lastWordPos) * CHAR_BIT, 1); 441 } 442 443 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 444 /// format. 445 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 446 APInt &result) { 447 assert(llvm::support::endian::system_endianness() == // NOLINT 448 llvm::support::endianness::big); // NOLINT 449 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 450 451 // Copy the data that fills the word of `result` from `inArray`. 452 // For example, when `result` has 2 words, the first word will be filled with 453 // data. So, the first 8 bytes are copied from `inArray` here. 454 // `inArray` (10 bytes, BE): |abcdefgh|ij| 455 // ==> `result` (2 words, BE): |abcdefgh|--------| 456 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 457 std::copy_n( 458 inArray, numFilledWords, 459 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 460 461 // Convert array data which will be last word of `result` to LE format, and 462 // store it in char array(`inArrayLE`). 463 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 464 size_t lastWordPos = numFilledWords; 465 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 466 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 467 inArray + lastWordPos, inArrayLE.begin(), 468 (numBytes - lastWordPos) * CHAR_BIT, 1); 469 470 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 471 // ex. 
`inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 472 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 473 inArrayLE.begin(), 474 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 475 lastWordPos, 476 APInt::APINT_BITS_PER_WORD, 1); 477 } 478 479 /// Writes value to the bit position `bitPos` in array `rawData`. 480 static void writeBits(char *rawData, size_t bitPos, APInt value) { 481 size_t bitWidth = value.getBitWidth(); 482 483 // If the bitwidth is 1 we just toggle the specific bit. 484 if (bitWidth == 1) 485 return setBit(rawData, bitPos, value.isOneValue()); 486 487 // Otherwise, the bit position is guaranteed to be byte aligned. 488 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 489 if (llvm::support::endian::system_endianness() == 490 llvm::support::endianness::big) { 491 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 492 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 493 // work correctly in BE format. 494 // ex. `value` (2 words including 10 bytes) 495 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 496 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 497 rawData + (bitPos / CHAR_BIT)); 498 } else { 499 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 500 llvm::divideCeil(bitWidth, CHAR_BIT), 501 rawData + (bitPos / CHAR_BIT)); 502 } 503 } 504 505 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 506 /// `rawData`. 507 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 508 // Handle a boolean bit position. 509 if (bitWidth == 1) 510 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 511 512 // Otherwise, the bit position must be 8-bit aligned. 
513 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 514 APInt result(bitWidth, 0); 515 if (llvm::support::endian::system_endianness() == 516 llvm::support::endianness::big) { 517 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 518 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 519 // work correctly in BE format. 520 // ex. `result` (2 words including 10 bytes) 521 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 522 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 523 llvm::divideCeil(bitWidth, CHAR_BIT), result); 524 } else { 525 std::copy_n(rawData + (bitPos / CHAR_BIT), 526 llvm::divideCeil(bitWidth, CHAR_BIT), 527 const_cast<char *>( 528 reinterpret_cast<const char *>(result.getRawData()))); 529 } 530 return result; 531 } 532 533 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 534 /// the same element count as 'type'. 535 template <typename Values> 536 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 537 return (values.size() == 1) || 538 (type.getNumElements() == static_cast<int64_t>(values.size())); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // DenseElementsAttr Iterators 543 //===----------------------------------------------------------------------===// 544 545 //===----------------------------------------------------------------------===// 546 // AttributeElementIterator 547 548 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 549 DenseElementsAttr attr, size_t index) 550 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 551 Attribute, Attribute, Attribute>( 552 attr.getAsOpaquePointer(), index) {} 553 554 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 555 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 556 Type eltTy = owner.getElementType(); 557 if (auto 
intEltTy = eltTy.dyn_cast<IntegerType>()) 558 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 559 if (eltTy.isa<IndexType>()) 560 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 561 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 562 IntElementIterator intIt(owner, index); 563 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 564 return FloatAttr::get(eltTy, *floatIt); 565 } 566 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 567 auto complexEltTy = complexTy.getElementType(); 568 ComplexIntElementIterator complexIntIt(owner, index); 569 if (complexEltTy.isa<IntegerType>()) { 570 auto value = *complexIntIt; 571 auto real = IntegerAttr::get(complexEltTy, value.real()); 572 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 573 return ArrayAttr::get(complexTy.getContext(), 574 ArrayRef<Attribute>{real, imag}); 575 } 576 577 ComplexFloatElementIterator complexFloatIt( 578 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 579 auto value = *complexFloatIt; 580 auto real = FloatAttr::get(complexEltTy, value.real()); 581 auto imag = FloatAttr::get(complexEltTy, value.imag()); 582 return ArrayAttr::get(complexTy.getContext(), 583 ArrayRef<Attribute>{real, imag}); 584 } 585 if (owner.isa<DenseStringElementsAttr>()) { 586 ArrayRef<StringRef> vals = owner.getRawStringData(); 587 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 588 } 589 llvm_unreachable("unexpected element type"); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // BoolElementIterator 594 595 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 596 DenseElementsAttr attr, size_t dataIndex) 597 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 598 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 599 600 bool DenseElementsAttr::BoolElementIterator::operator*() const { 601 return getBit(getData(), getDataIndex()); 602 } 603 604 //===----------------------------------------------------------------------===// 605 // IntElementIterator 606 607 DenseElementsAttr::IntElementIterator::IntElementIterator( 608 DenseElementsAttr attr, size_t dataIndex) 609 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 610 attr.getRawData().data(), attr.isSplat(), dataIndex), 611 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 612 613 APInt DenseElementsAttr::IntElementIterator::operator*() const { 614 return readBits(getData(), 615 getDataIndex() * getDenseElementStorageWidth(bitWidth), 616 bitWidth); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // ComplexIntElementIterator 621 622 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 623 DenseElementsAttr attr, size_t dataIndex) 624 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 625 std::complex<APInt>, std::complex<APInt>, 626 std::complex<APInt>>( 627 attr.getRawData().data(), attr.isSplat(), dataIndex) { 628 auto complexType = attr.getElementType().cast<ComplexType>(); 629 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 630 } 631 632 std::complex<APInt> 633 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 634 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 635 size_t offset = getDataIndex() * 
storageWidth * 2; 636 return {readBits(getData(), offset, bitWidth), 637 readBits(getData(), offset + storageWidth, bitWidth)}; 638 } 639 640 //===----------------------------------------------------------------------===// 641 // FloatElementIterator 642 643 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 644 const llvm::fltSemantics &smt, IntElementIterator it) 645 : llvm::mapped_iterator<IntElementIterator, 646 std::function<APFloat(const APInt &)>>( 647 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 648 649 //===----------------------------------------------------------------------===// 650 // ComplexFloatElementIterator 651 652 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 653 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 654 : llvm::mapped_iterator< 655 ComplexIntElementIterator, 656 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 657 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 658 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 659 }) {} 660 661 //===----------------------------------------------------------------------===// 662 // DenseElementsAttr 663 //===----------------------------------------------------------------------===// 664 665 /// Method for support type inquiry through isa, cast and dyn_cast. 666 bool DenseElementsAttr::classof(Attribute attr) { 667 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 668 } 669 670 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 671 ArrayRef<Attribute> values) { 672 assert(hasSameElementsOrSplat(type, values)); 673 674 // If the element type is not based on int/float/index, assume it is a string 675 // type. 
676 auto eltType = type.getElementType(); 677 if (!type.getElementType().isIntOrIndexOrFloat()) { 678 SmallVector<StringRef, 8> stringValues; 679 stringValues.reserve(values.size()); 680 for (Attribute attr : values) { 681 assert(attr.isa<StringAttr>() && 682 "expected string value for non integer/index/float element"); 683 stringValues.push_back(attr.cast<StringAttr>().getValue()); 684 } 685 return get(type, stringValues); 686 } 687 688 // Otherwise, get the raw storage width to use for the allocation. 689 size_t bitWidth = getDenseElementBitWidth(eltType); 690 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 691 692 // Compress the attribute values into a character buffer. 693 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 694 values.size()); 695 APInt intVal; 696 for (unsigned i = 0, e = values.size(); i < e; ++i) { 697 assert(eltType == values[i].getType() && 698 "expected attribute value to have element type"); 699 if (eltType.isa<FloatType>()) 700 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 701 else if (eltType.isa<IntegerType, IndexType>()) 702 intVal = values[i].cast<IntegerAttr>().getValue(); 703 else 704 llvm_unreachable("unexpected element type"); 705 706 assert(intVal.getBitWidth() == bitWidth && 707 "expected value to have same bitwidth as element type"); 708 writeBits(data.data(), i * storageBitWidth, intVal); 709 } 710 return DenseIntOrFPElementsAttr::getRaw(type, data, 711 /*isSplat=*/(values.size() == 1)); 712 } 713 714 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 715 ArrayRef<bool> values) { 716 assert(hasSameElementsOrSplat(type, values)); 717 assert(type.getElementType().isInteger(1)); 718 719 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 720 for (int i = 0, e = values.size(); i != e; ++i) 721 setBit(buff.data(), i, values[i]); 722 return DenseIntOrFPElementsAttr::getRaw(type, buff, 723 /*isSplat=*/(values.size() == 1)); 724 } 725 726 
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  // Mirror the routing predicate of the ArrayRef<Attribute> overload: only
  // element types that are not int/index/float may be stored as strings.
  assert(!type.getElementType().isIntOrIndexOrFloat() &&
         "expected non integer/index/float element type for string values");
  assert(hasSameElementsOrSplat(type, values));
  return DenseStringElementsAttr::get(type, values);
}

/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  assert(type.getElementType().isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));
  size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
                                          /*isSplat=*/(values.size() == 1));
}
/// Constructs a dense complex-integer elements attribute. The values are
/// handed to the raw builder as a flat list of 2 * values.size() APInts
/// (real, imag, real, imag, ...), each using half the complex storage width.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  ComplexType complex = type.getElementType().cast<ComplexType>();
  assert(complex.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));
  size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
  // NOTE(review): relies on std::complex<APInt> laying out its real and
  // imaginary parts as two adjacent APInts; the standard does not guarantee
  // this for non-floating-point element types.
  ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
                          values.size() * 2);
  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
                                          /*isSplat=*/(values.size() == 1));
}

/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'.
758 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 759 ArrayRef<APFloat> values) { 760 assert(type.getElementType().isa<FloatType>()); 761 assert(hasSameElementsOrSplat(type, values)); 762 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 763 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 764 /*isSplat=*/(values.size() == 1)); 765 } 766 DenseElementsAttr 767 DenseElementsAttr::get(ShapedType type, 768 ArrayRef<std::complex<APFloat>> values) { 769 ComplexType complex = type.getElementType().cast<ComplexType>(); 770 assert(complex.getElementType().isa<FloatType>()); 771 assert(hasSameElementsOrSplat(type, values)); 772 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 773 values.size() * 2); 774 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 775 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 776 /*isSplat=*/(values.size() == 1)); 777 } 778 779 /// Construct a dense elements attribute from a raw buffer representing the 780 /// data for this attribute. Users should generally not use this methods as 781 /// the expected buffer format may not be a form the user expects. 782 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 783 ArrayRef<char> rawBuffer, 784 bool isSplatBuffer) { 785 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 786 } 787 788 /// Returns true if the given buffer is a valid raw buffer for the given type. 789 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 790 ArrayRef<char> rawBuffer, 791 bool &detectedSplat) { 792 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 793 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 794 795 // Storage width of 1 is special as it is packed by the bit. 796 if (storageWidth == 1) { 797 // Check for a splat, or a buffer equal to the number of elements which 798 // consists of either all 0's or all 1's. 
799 detectedSplat = false; 800 if (rawBuffer.size() == 1) { 801 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 802 if (rawByte == 0 || rawByte == 0xff) { 803 detectedSplat = true; 804 return true; 805 } 806 } 807 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 808 } 809 // All other types are 8-bit aligned. 810 if ((detectedSplat = rawBufferWidth == storageWidth)) 811 return true; 812 return rawBufferWidth == (storageWidth * type.getNumElements()); 813 } 814 815 /// Check the information for a C++ data type, check if this type is valid for 816 /// the current attribute. This method is used to verify specific type 817 /// invariants that the templatized 'getValues' method cannot. 818 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 819 bool isSigned) { 820 // Make sure that the data element size is the same as the type element width. 821 if (getDenseElementBitWidth(type) != 822 static_cast<size_t>(dataEltSize * CHAR_BIT)) 823 return false; 824 825 // Check that the element type is either float or integer or index. 826 if (!isInt) 827 return type.isa<FloatType>(); 828 if (type.isIndex()) 829 return true; 830 831 auto intType = type.dyn_cast<IntegerType>(); 832 if (!intType) 833 return false; 834 835 // Make sure signedness semantics is consistent. 836 if (intType.isSignless()) 837 return true; 838 return intType.isSigned() ? isSigned : !isSigned; 839 } 840 841 /// Defaults down the subclass implementation. 
842 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 843 ArrayRef<char> data, 844 int64_t dataEltSize, 845 bool isInt, bool isSigned) { 846 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 847 isSigned); 848 } 849 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 850 ArrayRef<char> data, 851 int64_t dataEltSize, 852 bool isInt, 853 bool isSigned) { 854 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 855 isInt, isSigned); 856 } 857 858 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 859 bool isSigned) const { 860 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 861 } 862 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 863 bool isSigned) const { 864 return ::isValidIntOrFloat( 865 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 866 isInt, isSigned); 867 } 868 869 /// Returns true if this attribute corresponds to a splat, i.e. if all element 870 /// values are the same. 871 bool DenseElementsAttr::isSplat() const { 872 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 873 } 874 875 /// Return if the given complex type has an integer element type. 
876 static bool isComplexOfIntType(Type type) { 877 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 878 } 879 880 auto DenseElementsAttr::getComplexIntValues() const 881 -> llvm::iterator_range<ComplexIntElementIterator> { 882 assert(isComplexOfIntType(getElementType()) && 883 "expected complex integral type"); 884 return {ComplexIntElementIterator(*this, 0), 885 ComplexIntElementIterator(*this, getNumElements())}; 886 } 887 auto DenseElementsAttr::complex_value_begin() const 888 -> ComplexIntElementIterator { 889 assert(isComplexOfIntType(getElementType()) && 890 "expected complex integral type"); 891 return ComplexIntElementIterator(*this, 0); 892 } 893 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 894 assert(isComplexOfIntType(getElementType()) && 895 "expected complex integral type"); 896 return ComplexIntElementIterator(*this, getNumElements()); 897 } 898 899 /// Return the held element values as a range of APFloat. The element type of 900 /// this attribute must be of float type. 
901 auto DenseElementsAttr::getFloatValues() const 902 -> llvm::iterator_range<FloatElementIterator> { 903 auto elementType = getElementType().cast<FloatType>(); 904 const auto &elementSemantics = elementType.getFloatSemantics(); 905 return {FloatElementIterator(elementSemantics, raw_int_begin()), 906 FloatElementIterator(elementSemantics, raw_int_end())}; 907 } 908 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 909 auto elementType = getElementType().cast<FloatType>(); 910 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 911 } 912 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 913 auto elementType = getElementType().cast<FloatType>(); 914 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 915 } 916 917 auto DenseElementsAttr::getComplexFloatValues() const 918 -> llvm::iterator_range<ComplexFloatElementIterator> { 919 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 920 assert(eltTy.isa<FloatType>() && "expected complex float type"); 921 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 922 return {{semantics, {*this, 0}}, 923 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 924 } 925 auto DenseElementsAttr::complex_float_value_begin() const 926 -> ComplexFloatElementIterator { 927 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 928 assert(eltTy.isa<FloatType>() && "expected complex float type"); 929 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 930 } 931 auto DenseElementsAttr::complex_float_value_end() const 932 -> ComplexFloatElementIterator { 933 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 934 assert(eltTy.isa<FloatType>() && "expected complex float type"); 935 return {eltTy.cast<FloatType>().getFloatSemantics(), 936 {*this, static_cast<size_t>(getNumElements())}}; 937 } 938 939 /// Return the raw storage data held by this attribute. 
940 ArrayRef<char> DenseElementsAttr::getRawData() const { 941 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 942 } 943 944 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 945 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 946 } 947 948 /// Return a new DenseElementsAttr that has the same data as the current 949 /// attribute, but has been reshaped to 'newType'. The new type must have the 950 /// same total number of elements as well as element type. 951 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 952 ShapedType curType = getType(); 953 if (curType == newType) 954 return *this; 955 956 assert(newType.getElementType() == curType.getElementType() && 957 "expected the same element type"); 958 assert(newType.getNumElements() == curType.getNumElements() && 959 "expected the same number of elements"); 960 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 961 } 962 963 /// Return a new DenseElementsAttr that has the same data as the current 964 /// attribute, but has bitcast elements such that it is now 'newType'. The new 965 /// type must have the same shape and element types of the same bitwidth as the 966 /// current type. 
967 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 968 ShapedType curType = getType(); 969 Type curElType = curType.getElementType(); 970 if (curElType == newElType) 971 return *this; 972 973 assert(getDenseElementBitWidth(newElType) == 974 getDenseElementBitWidth(curElType) && 975 "expected element types with the same bitwidth"); 976 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 977 getRawData(), isSplat()); 978 } 979 980 DenseElementsAttr 981 DenseElementsAttr::mapValues(Type newElementType, 982 function_ref<APInt(const APInt &)> mapping) const { 983 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 984 } 985 986 DenseElementsAttr DenseElementsAttr::mapValues( 987 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 988 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 989 } 990 991 ShapedType DenseElementsAttr::getType() const { 992 return Attribute::getType().cast<ShapedType>(); 993 } 994 995 Type DenseElementsAttr::getElementType() const { 996 return getType().getElementType(); 997 } 998 999 int64_t DenseElementsAttr::getNumElements() const { 1000 return getType().getNumElements(); 1001 } 1002 1003 //===----------------------------------------------------------------------===// 1004 // DenseIntOrFPElementsAttr 1005 //===----------------------------------------------------------------------===// 1006 1007 /// Utility method to write a range of APInt values to a buffer. 
1008 template <typename APRangeT> 1009 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1010 APRangeT &&values) { 1011 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1012 size_t offset = 0; 1013 for (auto it = values.begin(), e = values.end(); it != e; 1014 ++it, offset += storageWidth) { 1015 assert((*it).getBitWidth() <= storageWidth); 1016 writeBits(data.data(), offset, *it); 1017 } 1018 } 1019 1020 /// Constructs a dense elements attribute from an array of raw APFloat values. 1021 /// Each APFloat value is expected to have the same bitwidth as the element 1022 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1023 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1024 size_t storageWidth, 1025 ArrayRef<APFloat> values, 1026 bool isSplat) { 1027 std::vector<char> data; 1028 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1029 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1030 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1031 } 1032 1033 /// Constructs a dense elements attribute from an array of raw APInt values. 1034 /// Each APInt value is expected to have the same bitwidth as the element type 1035 /// of 'type'. 
1036 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1037 size_t storageWidth, 1038 ArrayRef<APInt> values, 1039 bool isSplat) { 1040 std::vector<char> data; 1041 writeAPIntsToBuffer(storageWidth, data, values); 1042 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1043 } 1044 1045 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1046 ArrayRef<char> data, 1047 bool isSplat) { 1048 assert((type.isa<RankedTensorType, VectorType>()) && 1049 "type must be ranked tensor or vector"); 1050 assert(type.hasStaticShape() && "type must have static shape"); 1051 return Base::get(type.getContext(), type, data, isSplat); 1052 } 1053 1054 /// Overload of the raw 'get' method that asserts that the given type is of 1055 /// complex type. This method is used to verify type invariants that the 1056 /// templatized 'get' method cannot. 1057 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1058 ArrayRef<char> data, 1059 int64_t dataEltSize, 1060 bool isInt, 1061 bool isSigned) { 1062 assert(::isValidIntOrFloat( 1063 type.getElementType().cast<ComplexType>().getElementType(), 1064 dataEltSize / 2, isInt, isSigned)); 1065 1066 int64_t numElements = data.size() / dataEltSize; 1067 assert(numElements == 1 || numElements == type.getNumElements()); 1068 return getRaw(type, data, /*isSplat=*/numElements == 1); 1069 } 1070 1071 /// Overload of the 'getRaw' method that asserts that the given type is of 1072 /// integer type. This method is used to verify type invariants that the 1073 /// templatized 'get' method cannot. 
1074 DenseElementsAttr 1075 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1076 int64_t dataEltSize, bool isInt, 1077 bool isSigned) { 1078 assert( 1079 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1080 1081 int64_t numElements = data.size() / dataEltSize; 1082 assert(numElements == 1 || numElements == type.getNumElements()); 1083 return getRaw(type, data, /*isSplat=*/numElements == 1); 1084 } 1085 1086 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1087 const char *inRawData, char *outRawData, size_t elementBitWidth, 1088 size_t numElements) { 1089 using llvm::support::ulittle16_t; 1090 using llvm::support::ulittle32_t; 1091 using llvm::support::ulittle64_t; 1092 1093 assert(llvm::support::endian::system_endianness() == // NOLINT 1094 llvm::support::endianness::big); // NOLINT 1095 // NOLINT to avoid warning message about replacing by static_assert() 1096 1097 // Following std::copy_n always converts endianness on BE machine. 
1098 switch (elementBitWidth) { 1099 case 16: { 1100 const ulittle16_t *inRawDataPos = 1101 reinterpret_cast<const ulittle16_t *>(inRawData); 1102 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1103 std::copy_n(inRawDataPos, numElements, outDataPos); 1104 break; 1105 } 1106 case 32: { 1107 const ulittle32_t *inRawDataPos = 1108 reinterpret_cast<const ulittle32_t *>(inRawData); 1109 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1110 std::copy_n(inRawDataPos, numElements, outDataPos); 1111 break; 1112 } 1113 case 64: { 1114 const ulittle64_t *inRawDataPos = 1115 reinterpret_cast<const ulittle64_t *>(inRawData); 1116 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1117 std::copy_n(inRawDataPos, numElements, outDataPos); 1118 break; 1119 } 1120 default: { 1121 size_t nBytes = elementBitWidth / CHAR_BIT; 1122 for (size_t i = 0; i < nBytes; i++) 1123 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1124 break; 1125 } 1126 } 1127 } 1128 1129 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1130 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1131 ShapedType type) { 1132 size_t numElements = type.getNumElements(); 1133 Type elementType = type.getElementType(); 1134 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1135 elementType = complexTy.getElementType(); 1136 numElements = numElements * 2; 1137 } 1138 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1139 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1140 inRawData.size() <= outRawData.size()); 1141 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1142 elementBitWidth, numElements); 1143 } 1144 1145 //===----------------------------------------------------------------------===// 1146 // DenseFPElementsAttr 1147 //===----------------------------------------------------------------------===// 1148 1149 template <typename Fn, typename Attr> 
1150 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1151 Type newElementType, 1152 llvm::SmallVectorImpl<char> &data) { 1153 size_t bitWidth = getDenseElementBitWidth(newElementType); 1154 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1155 1156 ShapedType newArrayType; 1157 if (inType.isa<RankedTensorType>()) 1158 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1159 else if (inType.isa<UnrankedTensorType>()) 1160 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1161 else if (inType.isa<VectorType>()) 1162 newArrayType = VectorType::get(inType.getShape(), newElementType); 1163 else 1164 assert(newArrayType && "Unhandled tensor type"); 1165 1166 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1167 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1168 1169 // Functor used to process a single element value of the attribute. 1170 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1171 auto newInt = mapping(value); 1172 assert(newInt.getBitWidth() == bitWidth); 1173 writeBits(data.data(), index * storageBitWidth, newInt); 1174 }; 1175 1176 // Check for the splat case. 1177 if (attr.isSplat()) { 1178 processElt(*attr.begin(), /*index=*/0); 1179 return newArrayType; 1180 } 1181 1182 // Otherwise, process all of the element values. 1183 uint64_t elementIdx = 0; 1184 for (auto value : attr) 1185 processElt(value, elementIdx++); 1186 return newArrayType; 1187 } 1188 1189 DenseElementsAttr DenseFPElementsAttr::mapValues( 1190 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1191 llvm::SmallVector<char, 8> elementData; 1192 auto newArrayType = 1193 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1194 1195 return getRaw(newArrayType, elementData, isSplat()); 1196 } 1197 1198 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1199 bool DenseFPElementsAttr::classof(Attribute attr) { 1200 return attr.isa<DenseElementsAttr>() && 1201 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1202 } 1203 1204 //===----------------------------------------------------------------------===// 1205 // DenseIntElementsAttr 1206 //===----------------------------------------------------------------------===// 1207 1208 DenseElementsAttr DenseIntElementsAttr::mapValues( 1209 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1210 llvm::SmallVector<char, 8> elementData; 1211 auto newArrayType = 1212 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1213 1214 return getRaw(newArrayType, elementData, isSplat()); 1215 } 1216 1217 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1218 bool DenseIntElementsAttr::classof(Attribute attr) { 1219 return attr.isa<DenseElementsAttr>() && 1220 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1221 } 1222 1223 //===----------------------------------------------------------------------===// 1224 // OpaqueElementsAttr 1225 //===----------------------------------------------------------------------===// 1226 1227 /// Return the value at the given index. If index does not refer to a valid 1228 /// element, then a null attribute is returned. 
1229 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1230 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1231 return Attribute(); 1232 } 1233 1234 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1235 Dialect *dialect = getDialect().getDialect(); 1236 if (!dialect) 1237 return true; 1238 auto *interface = 1239 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1240 if (!interface) 1241 return true; 1242 return failed(interface->decode(*this, result)); 1243 } 1244 1245 LogicalResult 1246 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1247 Identifier dialect, StringRef value, 1248 ShapedType type) { 1249 if (!Dialect::isValidNamespace(dialect.strref())) 1250 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1251 return success(); 1252 } 1253 1254 //===----------------------------------------------------------------------===// 1255 // SparseElementsAttr 1256 //===----------------------------------------------------------------------===// 1257 1258 /// Return the value of the element at the given index. 1259 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1260 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1261 auto type = getType(); 1262 1263 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1264 // as a 1-D index array. 1265 auto sparseIndices = getIndices(); 1266 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1267 1268 // Check to see if the indices are a splat. 1269 if (sparseIndices.isSplat()) { 1270 // If the index is also not a splat of the index value, we know that the 1271 // value is zero. 1272 auto splatIndex = *sparseIndexValues.begin(); 1273 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1274 return getZeroAttr(); 1275 1276 // If the indices are a splat, we also expect the values to be a splat. 
1277 assert(getValues().isSplat() && "expected splat values"); 1278 return getValues().getSplatValue(); 1279 } 1280 1281 // Build a mapping between known indices and the offset of the stored element. 1282 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1283 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1284 size_t rank = type.getRank(); 1285 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1286 mappedIndices.try_emplace( 1287 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1288 1289 // Look for the provided index key within the mapped indices. If the provided 1290 // index is not found, then return a zero attribute. 1291 auto it = mappedIndices.find(index); 1292 if (it == mappedIndices.end()) 1293 return getZeroAttr(); 1294 1295 // Otherwise, return the held sparse value element. 1296 return getValues().getValue(it->second); 1297 } 1298 1299 /// Get a zero APFloat for the given sparse attribute. 1300 APFloat SparseElementsAttr::getZeroAPFloat() const { 1301 auto eltType = getElementType().cast<FloatType>(); 1302 return APFloat(eltType.getFloatSemantics()); 1303 } 1304 1305 /// Get a zero APInt for the given sparse attribute. 1306 APInt SparseElementsAttr::getZeroAPInt() const { 1307 auto eltType = getElementType().cast<IntegerType>(); 1308 return APInt::getZero(eltType.getWidth()); 1309 } 1310 1311 /// Get a zero attribute for the given attribute type. 1312 Attribute SparseElementsAttr::getZeroAttr() const { 1313 auto eltType = getElementType(); 1314 1315 // Handle floating point elements. 1316 if (eltType.isa<FloatType>()) 1317 return FloatAttr::get(eltType, 0); 1318 1319 // Otherwise, this is an integer. 1320 // TODO: Handle StringAttr here. 1321 return IntegerAttr::get(eltType, 0); 1322 } 1323 1324 /// Flatten, and return, all of the sparse indices in this attribute in 1325 /// row-major order. 
1326 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1327 std::vector<ptrdiff_t> flatSparseIndices; 1328 1329 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1330 // as a 1-D index array. 1331 auto sparseIndices = getIndices(); 1332 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1333 if (sparseIndices.isSplat()) { 1334 SmallVector<uint64_t, 8> indices(getType().getRank(), 1335 *sparseIndexValues.begin()); 1336 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1337 return flatSparseIndices; 1338 } 1339 1340 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1341 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1342 size_t rank = getType().getRank(); 1343 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1344 flatSparseIndices.push_back(getFlattenedIndex( 1345 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1346 return flatSparseIndices; 1347 } 1348 1349 LogicalResult 1350 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1351 ShapedType type, DenseIntElementsAttr sparseIndices, 1352 DenseElementsAttr values) { 1353 ShapedType valuesType = values.getType(); 1354 if (valuesType.getRank() != 1) 1355 return emitError() << "expected 1-d tensor for sparse element values"; 1356 1357 // Verify the indices and values shape. 1358 ShapedType indicesType = sparseIndices.getType(); 1359 auto emitShapeError = [&]() { 1360 return emitError() << "expected shape ([" << type.getShape() 1361 << "]); inferred shape of indices literal ([" 1362 << indicesType.getShape() 1363 << "]); inferred shape of values literal ([" 1364 << valuesType.getShape() << "])"; 1365 }; 1366 // Verify indices shape. 
1367 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1368 if (indicesRank == 2) { 1369 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1370 return emitShapeError(); 1371 } else if (indicesRank != 1 || rank != 1) { 1372 return emitShapeError(); 1373 } 1374 // Verify the values shape. 1375 int64_t numSparseIndices = indicesType.getDimSize(0); 1376 if (numSparseIndices != valuesType.getDimSize(0)) 1377 return emitShapeError(); 1378 1379 // Verify that the sparse indices are within the value shape. 1380 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1381 return emitError() 1382 << "sparse index #" << indexNum 1383 << " is not contained within the value shape, with index=[" << index 1384 << "], and type=" << type; 1385 }; 1386 1387 // Handle the case where the index values are a splat. 1388 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1389 if (sparseIndices.isSplat()) { 1390 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1391 if (!ElementsAttr::isValidIndex(type, indices)) 1392 return emitIndexError(0, indices); 1393 return success(); 1394 } 1395 1396 // Otherwise, reinterpret each index as an ArrayRef. 1397 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1398 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1399 rank); 1400 if (!ElementsAttr::isValidIndex(type, index)) 1401 return emitIndexError(i, index); 1402 } 1403 1404 return success(); 1405 } 1406 1407 //===----------------------------------------------------------------------===// 1408 // TypeAttr 1409 //===----------------------------------------------------------------------===// 1410 1411 void TypeAttr::walkImmediateSubElements( 1412 function_ref<void(Attribute)> walkAttrsFn, 1413 function_ref<void(Type)> walkTypesFn) const { 1414 walkTypesFn(getValue()); 1415 } 1416