1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> 
walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // DictionaryAttr 58 //===----------------------------------------------------------------------===// 59 60 /// Helper function that does either an in place sort or sorts from source array 61 /// into destination. If inPlace then storage is both the source and the 62 /// destination, else value is the source and storage destination. Returns 63 /// whether source was sorted. 64 template <bool inPlace> 65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 66 SmallVectorImpl<NamedAttribute> &storage) { 67 // Specialize for the common case. 68 switch (value.size()) { 69 case 0: 70 // Zero already sorted. 71 break; 72 case 1: 73 // One already sorted but may need to be copied. 74 if (!inPlace) 75 storage.assign({value[0]}); 76 break; 77 case 2: { 78 bool isSorted = value[0] < value[1]; 79 if (inPlace) { 80 if (!isSorted) 81 std::swap(storage[0], storage[1]); 82 } else if (isSorted) { 83 storage.assign({value[0], value[1]}); 84 } else { 85 storage.assign({value[1], value[0]}); 86 } 87 return !isSorted; 88 } 89 default: 90 if (!inPlace) 91 storage.assign(value.begin(), value.end()); 92 // Check to see they are sorted already. 93 bool isSorted = llvm::is_sorted(value); 94 // If not, do a general sort. 95 if (!isSorted) 96 llvm::array_pod_sort(storage.begin(), storage.end()); 97 return !isSorted; 98 } 99 return false; 100 } 101 102 /// Returns an entry with a duplicate name from the given sorted array of named 103 /// attributes. Returns llvm::None if all elements have unique names. 104 static Optional<NamedAttribute> 105 findDuplicateElement(ArrayRef<NamedAttribute> value) { 106 const Optional<NamedAttribute> none{llvm::None}; 107 if (value.size() < 2) 108 return none; 109 110 if (value.size() == 2) 111 return value[0].first == value[1].first ? 
value[0] : none; 112 113 auto it = std::adjacent_find( 114 value.begin(), value.end(), 115 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 116 return it != value.end() ? *it : none; 117 } 118 119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 120 SmallVectorImpl<NamedAttribute> &storage) { 121 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 122 assert(!findDuplicateElement(storage) && 123 "DictionaryAttr element names must be unique"); 124 return isSorted; 125 } 126 127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 128 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 129 assert(!findDuplicateElement(array) && 130 "DictionaryAttr element names must be unique"); 131 return isSorted; 132 } 133 134 Optional<NamedAttribute> 135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 136 bool isSorted) { 137 if (!isSorted) 138 dictionaryAttrSort</*inPlace=*/true>(array, array); 139 return findDuplicateElement(array); 140 } 141 142 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 143 ArrayRef<NamedAttribute> value) { 144 if (value.empty()) 145 return DictionaryAttr::getEmpty(context); 146 assert(llvm::all_of(value, 147 [](const NamedAttribute &attr) { return attr.second; }) && 148 "value cannot have null entries"); 149 150 // We need to sort the element list to canonicalize it. 151 SmallVector<NamedAttribute, 8> storage; 152 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 153 value = storage; 154 assert(!findDuplicateElement(value) && 155 "DictionaryAttr element names must be unique"); 156 return Base::get(context, value); 157 } 158 /// Construct a dictionary with an array of values that is known to already be 159 /// sorted by name and uniqued. 
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 161 ArrayRef<NamedAttribute> value) { 162 if (value.empty()) 163 return DictionaryAttr::getEmpty(context); 164 // Ensure that the attribute elements are unique and sorted. 165 assert(llvm::is_sorted(value, 166 [](NamedAttribute l, NamedAttribute r) { 167 return l.first.strref() < r.first.strref(); 168 }) && 169 "expected attribute values to be sorted"); 170 assert(!findDuplicateElement(value) && 171 "DictionaryAttr element names must be unique"); 172 return Base::get(context, value); 173 } 174 175 /// Return the specified attribute if present, null otherwise. 176 Attribute DictionaryAttr::get(StringRef name) const { 177 Optional<NamedAttribute> attr = getNamed(name); 178 return attr ? attr->second : nullptr; 179 } 180 Attribute DictionaryAttr::get(Identifier name) const { 181 Optional<NamedAttribute> attr = getNamed(name); 182 return attr ? attr->second : nullptr; 183 } 184 185 /// Return the specified named attribute if present, None otherwise. 186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 187 ArrayRef<NamedAttribute> values = getValue(); 188 const auto *it = llvm::lower_bound(values, name); 189 return it != values.end() && it->first == name ? 
*it 190 : Optional<NamedAttribute>(); 191 } 192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 193 for (auto elt : getValue()) 194 if (elt.first == name) 195 return elt; 196 return llvm::None; 197 } 198 199 DictionaryAttr::iterator DictionaryAttr::begin() const { 200 return getValue().begin(); 201 } 202 DictionaryAttr::iterator DictionaryAttr::end() const { 203 return getValue().end(); 204 } 205 size_t DictionaryAttr::size() const { return getValue().size(); } 206 207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 208 return Base::get(context, ArrayRef<NamedAttribute>()); 209 } 210 211 void DictionaryAttr::walkImmediateSubElements( 212 function_ref<void(Attribute)> walkAttrsFn, 213 function_ref<void(Type)> walkTypesFn) const { 214 for (Attribute attr : llvm::make_second_range(getValue())) 215 walkAttrsFn(attr); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // StringAttr 220 //===----------------------------------------------------------------------===// 221 222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 223 return Base::get(context, "", NoneType::get(context)); 224 } 225 226 /// Twine support for StringAttr. 227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 228 // Fast-path empty twine. 229 if (twine.isTriviallyEmpty()) 230 return get(context); 231 SmallVector<char, 32> tempStr; 232 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 233 } 234 235 /// Twine support for StringAttr. 
236 StringAttr StringAttr::get(const Twine &twine, Type type) { 237 SmallVector<char, 32> tempStr; 238 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // FloatAttr 243 //===----------------------------------------------------------------------===// 244 245 double FloatAttr::getValueAsDouble() const { 246 return getValueAsDouble(getValue()); 247 } 248 double FloatAttr::getValueAsDouble(APFloat value) { 249 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 250 bool losesInfo = false; 251 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 252 &losesInfo); 253 } 254 return value.convertToDouble(); 255 } 256 257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 258 Type type, APFloat value) { 259 // Verify that the type is correct. 260 if (!type.isa<FloatType>()) 261 return emitError() << "expected floating point type"; 262 263 // Verify that the type semantics match that of the value. 
264 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 265 return emitError() 266 << "FloatAttr type doesn't match the type implied by its value"; 267 } 268 return success(); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // SymbolRefAttr 273 //===----------------------------------------------------------------------===// 274 275 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 276 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 277 return get(StringAttr::get(ctx, value), nestedRefs); 278 } 279 280 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 281 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 282 } 283 284 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 285 return get(value, {}).cast<FlatSymbolRefAttr>(); 286 } 287 288 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 289 auto symName = 290 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 291 assert(symName && "value does not have a valid symbol name"); 292 return SymbolRefAttr::get(symName); 293 } 294 295 StringAttr SymbolRefAttr::getLeafReference() const { 296 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 297 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // IntegerAttr 302 //===----------------------------------------------------------------------===// 303 304 int64_t IntegerAttr::getInt() const { 305 assert((getType().isIndex() || getType().isSignlessInteger()) && 306 "must be signless integer"); 307 return getValue().getSExtValue(); 308 } 309 310 int64_t IntegerAttr::getSInt() const { 311 assert(getType().isSignedInteger() && "must be signed integer"); 312 return getValue().getSExtValue(); 313 } 314 315 uint64_t IntegerAttr::getUInt() const { 316 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 317 return getValue().getZExtValue(); 318 } 319 320 /// Return the value as an APSInt which carries the signed from the type of 321 /// the attribute. This traps on signless integers types! 322 APSInt IntegerAttr::getAPSInt() const { 323 assert(!getType().isSignlessInteger() && 324 "Signless integers don't carry a sign for APSInt"); 325 return APSInt(getValue(), getType().isUnsignedInteger()); 326 } 327 328 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 329 Type type, APInt value) { 330 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 331 if (integerType.getWidth() != value.getBitWidth()) 332 return emitError() << "integer type bit width (" << integerType.getWidth() 333 << ") doesn't match value bit width (" 334 << value.getBitWidth() << ")"; 335 return success(); 336 } 337 if (type.isa<IndexType>()) 338 return success(); 339 return emitError() << "expected integer or index type"; 340 } 341 342 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 343 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 344 return attr.cast<BoolAttr>(); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // BoolAttr 349 350 bool 
BoolAttr::getValue() const { 351 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 352 return storage->value.getBoolValue(); 353 } 354 355 bool BoolAttr::classof(Attribute attr) { 356 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 357 return intAttr && intAttr.getType().isSignlessInteger(1); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // OpaqueAttr 362 //===----------------------------------------------------------------------===// 363 364 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 365 Identifier dialect, StringRef attrData, 366 Type type) { 367 if (!Dialect::isValidNamespace(dialect.strref())) 368 return emitError() << "invalid dialect namespace '" << dialect << "'"; 369 370 // Check that the dialect is actually registered. 371 MLIRContext *context = dialect.getContext(); 372 if (!context->allowsUnregisteredDialects() && 373 !context->getLoadedDialect(dialect.strref())) { 374 return emitError() 375 << "#" << dialect << "<\"" << attrData << "\"> : " << type 376 << " attribute created with unregistered dialect. If this is " 377 "intended, please call allowUnregisteredDialects() on the " 378 "MLIRContext, or use -allow-unregistered-dialect with " 379 "the MLIR opt tool used"; 380 } 381 382 return success(); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // DenseElementsAttr Utilities 387 //===----------------------------------------------------------------------===// 388 389 /// Get the bitwidth of a dense element type within the buffer. 390 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 391 static size_t getDenseElementStorageWidth(size_t origWidth) { 392 return origWidth == 1 ? 
origWidth : llvm::alignTo<8>(origWidth); 393 } 394 static size_t getDenseElementStorageWidth(Type elementType) { 395 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 396 } 397 398 /// Set a bit to a specific value. 399 static void setBit(char *rawData, size_t bitPos, bool value) { 400 if (value) 401 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 402 else 403 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 404 } 405 406 /// Return the value of the specified bit. 407 static bool getBit(const char *rawData, size_t bitPos) { 408 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 409 } 410 411 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 412 /// BE format. 413 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 414 char *result) { 415 assert(llvm::support::endian::system_endianness() == // NOLINT 416 llvm::support::endianness::big); // NOLINT 417 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 418 419 // Copy the words filled with data. 420 // For example, when `value` has 2 words, the first word is filled with data. 421 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 422 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 423 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 424 numFilledWords, result); 425 // Convert last word of APInt to LE format and store it in char 426 // array(`valueLE`). 427 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 428 size_t lastWordPos = numFilledWords; 429 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 430 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 431 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 432 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 433 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 434 // and store it in `result`. 
435 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 436 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 437 valueLE.begin(), result + lastWordPos, 438 (numBytes - lastWordPos) * CHAR_BIT, 1); 439 } 440 441 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 442 /// format. 443 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 444 APInt &result) { 445 assert(llvm::support::endian::system_endianness() == // NOLINT 446 llvm::support::endianness::big); // NOLINT 447 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 448 449 // Copy the data that fills the word of `result` from `inArray`. 450 // For example, when `result` has 2 words, the first word will be filled with 451 // data. So, the first 8 bytes are copied from `inArray` here. 452 // `inArray` (10 bytes, BE): |abcdefgh|ij| 453 // ==> `result` (2 words, BE): |abcdefgh|--------| 454 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 455 std::copy_n( 456 inArray, numFilledWords, 457 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 458 459 // Convert array data which will be last word of `result` to LE format, and 460 // store it in char array(`inArrayLE`). 461 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 462 size_t lastWordPos = numFilledWords; 463 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 464 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 465 inArray + lastWordPos, inArrayLE.begin(), 466 (numBytes - lastWordPos) * CHAR_BIT, 1); 467 468 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 469 // ex. 
`inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 470 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 471 inArrayLE.begin(), 472 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 473 lastWordPos, 474 APInt::APINT_BITS_PER_WORD, 1); 475 } 476 477 /// Writes value to the bit position `bitPos` in array `rawData`. 478 static void writeBits(char *rawData, size_t bitPos, APInt value) { 479 size_t bitWidth = value.getBitWidth(); 480 481 // If the bitwidth is 1 we just toggle the specific bit. 482 if (bitWidth == 1) 483 return setBit(rawData, bitPos, value.isOneValue()); 484 485 // Otherwise, the bit position is guaranteed to be byte aligned. 486 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 487 if (llvm::support::endian::system_endianness() == 488 llvm::support::endianness::big) { 489 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 490 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 491 // work correctly in BE format. 492 // ex. `value` (2 words including 10 bytes) 493 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 494 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 495 rawData + (bitPos / CHAR_BIT)); 496 } else { 497 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 498 llvm::divideCeil(bitWidth, CHAR_BIT), 499 rawData + (bitPos / CHAR_BIT)); 500 } 501 } 502 503 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 504 /// `rawData`. 505 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 506 // Handle a boolean bit position. 507 if (bitWidth == 1) 508 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 509 510 // Otherwise, the bit position must be 8-bit aligned. 
511 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 512 APInt result(bitWidth, 0); 513 if (llvm::support::endian::system_endianness() == 514 llvm::support::endianness::big) { 515 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 516 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 517 // work correctly in BE format. 518 // ex. `result` (2 words including 10 bytes) 519 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 520 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 521 llvm::divideCeil(bitWidth, CHAR_BIT), result); 522 } else { 523 std::copy_n(rawData + (bitPos / CHAR_BIT), 524 llvm::divideCeil(bitWidth, CHAR_BIT), 525 const_cast<char *>( 526 reinterpret_cast<const char *>(result.getRawData()))); 527 } 528 return result; 529 } 530 531 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 532 /// the same element count as 'type'. 533 template <typename Values> 534 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 535 return (values.size() == 1) || 536 (type.getNumElements() == static_cast<int64_t>(values.size())); 537 } 538 539 //===----------------------------------------------------------------------===// 540 // DenseElementsAttr Iterators 541 //===----------------------------------------------------------------------===// 542 543 //===----------------------------------------------------------------------===// 544 // AttributeElementIterator 545 546 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 547 DenseElementsAttr attr, size_t index) 548 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 549 Attribute, Attribute, Attribute>( 550 attr.getAsOpaquePointer(), index) {} 551 552 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 553 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 554 Type eltTy = owner.getElementType(); 555 if (auto 
intEltTy = eltTy.dyn_cast<IntegerType>()) 556 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 557 if (eltTy.isa<IndexType>()) 558 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 559 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 560 IntElementIterator intIt(owner, index); 561 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 562 return FloatAttr::get(eltTy, *floatIt); 563 } 564 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 565 auto complexEltTy = complexTy.getElementType(); 566 ComplexIntElementIterator complexIntIt(owner, index); 567 if (complexEltTy.isa<IntegerType>()) { 568 auto value = *complexIntIt; 569 auto real = IntegerAttr::get(complexEltTy, value.real()); 570 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 571 return ArrayAttr::get(complexTy.getContext(), 572 ArrayRef<Attribute>{real, imag}); 573 } 574 575 ComplexFloatElementIterator complexFloatIt( 576 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 577 auto value = *complexFloatIt; 578 auto real = FloatAttr::get(complexEltTy, value.real()); 579 auto imag = FloatAttr::get(complexEltTy, value.imag()); 580 return ArrayAttr::get(complexTy.getContext(), 581 ArrayRef<Attribute>{real, imag}); 582 } 583 if (owner.isa<DenseStringElementsAttr>()) { 584 ArrayRef<StringRef> vals = owner.getRawStringData(); 585 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 586 } 587 llvm_unreachable("unexpected element type"); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // BoolElementIterator 592 593 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 594 DenseElementsAttr attr, size_t dataIndex) 595 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 596 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 597 598 bool DenseElementsAttr::BoolElementIterator::operator*() const { 599 return getBit(getData(), getDataIndex()); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // IntElementIterator 604 605 DenseElementsAttr::IntElementIterator::IntElementIterator( 606 DenseElementsAttr attr, size_t dataIndex) 607 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 608 attr.getRawData().data(), attr.isSplat(), dataIndex), 609 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 610 611 APInt DenseElementsAttr::IntElementIterator::operator*() const { 612 return readBits(getData(), 613 getDataIndex() * getDenseElementStorageWidth(bitWidth), 614 bitWidth); 615 } 616 617 //===----------------------------------------------------------------------===// 618 // ComplexIntElementIterator 619 620 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 621 DenseElementsAttr attr, size_t dataIndex) 622 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 623 std::complex<APInt>, std::complex<APInt>, 624 std::complex<APInt>>( 625 attr.getRawData().data(), attr.isSplat(), dataIndex) { 626 auto complexType = attr.getElementType().cast<ComplexType>(); 627 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 628 } 629 630 std::complex<APInt> 631 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 632 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 633 size_t offset = getDataIndex() * 
storageWidth * 2; 634 return {readBits(getData(), offset, bitWidth), 635 readBits(getData(), offset + storageWidth, bitWidth)}; 636 } 637 638 //===----------------------------------------------------------------------===// 639 // FloatElementIterator 640 641 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 642 const llvm::fltSemantics &smt, IntElementIterator it) 643 : llvm::mapped_iterator<IntElementIterator, 644 std::function<APFloat(const APInt &)>>( 645 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 646 647 //===----------------------------------------------------------------------===// 648 // ComplexFloatElementIterator 649 650 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 651 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 652 : llvm::mapped_iterator< 653 ComplexIntElementIterator, 654 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 655 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 656 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 657 }) {} 658 659 //===----------------------------------------------------------------------===// 660 // DenseElementsAttr 661 //===----------------------------------------------------------------------===// 662 663 /// Method for support type inquiry through isa, cast and dyn_cast. 664 bool DenseElementsAttr::classof(Attribute attr) { 665 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 666 } 667 668 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 669 ArrayRef<Attribute> values) { 670 assert(hasSameElementsOrSplat(type, values)); 671 672 // If the element type is not based on int/float/index, assume it is a string 673 // type. 
674 auto eltType = type.getElementType(); 675 if (!type.getElementType().isIntOrIndexOrFloat()) { 676 SmallVector<StringRef, 8> stringValues; 677 stringValues.reserve(values.size()); 678 for (Attribute attr : values) { 679 assert(attr.isa<StringAttr>() && 680 "expected string value for non integer/index/float element"); 681 stringValues.push_back(attr.cast<StringAttr>().getValue()); 682 } 683 return get(type, stringValues); 684 } 685 686 // Otherwise, get the raw storage width to use for the allocation. 687 size_t bitWidth = getDenseElementBitWidth(eltType); 688 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 689 690 // Compress the attribute values into a character buffer. 691 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 692 values.size()); 693 APInt intVal; 694 for (unsigned i = 0, e = values.size(); i < e; ++i) { 695 assert(eltType == values[i].getType() && 696 "expected attribute value to have element type"); 697 if (eltType.isa<FloatType>()) 698 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 699 else if (eltType.isa<IntegerType, IndexType>()) 700 intVal = values[i].cast<IntegerAttr>().getValue(); 701 else 702 llvm_unreachable("unexpected element type"); 703 704 assert(intVal.getBitWidth() == bitWidth && 705 "expected value to have same bitwidth as element type"); 706 writeBits(data.data(), i * storageBitWidth, intVal); 707 } 708 return DenseIntOrFPElementsAttr::getRaw(type, data, 709 /*isSplat=*/(values.size() == 1)); 710 } 711 712 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 713 ArrayRef<bool> values) { 714 assert(hasSameElementsOrSplat(type, values)); 715 assert(type.getElementType().isInteger(1)); 716 717 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 718 for (int i = 0, e = values.size(); i != e; ++i) 719 setBit(buff.data(), i, values[i]); 720 return DenseIntOrFPElementsAttr::getRaw(type, buff, 721 /*isSplat=*/(values.size() == 1)); 722 } 723 724 
/// Build a dense string elements attribute. The element type of `type` must
/// not be an integer or float type (checked in debug builds).
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  assert(!type.getElementType().isIntOrFloat());
  return DenseStringElementsAttr::get(type, values);
}

/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. A single value is stored as a splat.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type eltType = type.getElementType();
  assert(eltType.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));
  bool splat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(
      type, getDenseElementStorageWidth(eltType), values, /*isSplat=*/splat);
}

/// Constructs a dense complex-integer elements attribute. Each complex value
/// is stored as two adjacent integers (real part, then imaginary part).
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  auto complexTy = type.getElementType().cast<ComplexType>();
  assert(complexTy.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));

  // View the complex values as a flat array of twice as many APInts.
  // NOTE(review): this relies on std::complex<APInt> being laid out as two
  // consecutive APInts; the C++ standard only guarantees that layout for
  // floating-point element types — confirm on supported toolchains.
  ArrayRef<APInt> flatVals(reinterpret_cast<const APInt *>(values.data()),
                           2 * values.size());
  // Each half of a complex element gets half of the complex storage width.
  size_t partWidth = getDenseElementStorageWidth(complexTy) / 2;
  bool splat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(type, partWidth, flatVals,
                                          /*isSplat=*/splat);
}

/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'.
756 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 757 ArrayRef<APFloat> values) { 758 assert(type.getElementType().isa<FloatType>()); 759 assert(hasSameElementsOrSplat(type, values)); 760 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 761 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 762 /*isSplat=*/(values.size() == 1)); 763 } 764 DenseElementsAttr 765 DenseElementsAttr::get(ShapedType type, 766 ArrayRef<std::complex<APFloat>> values) { 767 ComplexType complex = type.getElementType().cast<ComplexType>(); 768 assert(complex.getElementType().isa<FloatType>()); 769 assert(hasSameElementsOrSplat(type, values)); 770 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 771 values.size() * 2); 772 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 773 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 774 /*isSplat=*/(values.size() == 1)); 775 } 776 777 /// Construct a dense elements attribute from a raw buffer representing the 778 /// data for this attribute. Users should generally not use this methods as 779 /// the expected buffer format may not be a form the user expects. 780 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 781 ArrayRef<char> rawBuffer, 782 bool isSplatBuffer) { 783 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 784 } 785 786 /// Returns true if the given buffer is a valid raw buffer for the given type. 787 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 788 ArrayRef<char> rawBuffer, 789 bool &detectedSplat) { 790 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 791 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 792 793 // Storage width of 1 is special as it is packed by the bit. 794 if (storageWidth == 1) { 795 // Check for a splat, or a buffer equal to the number of elements. 
796 if ((detectedSplat = rawBuffer.size() == 1)) 797 return true; 798 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 799 } 800 // All other types are 8-bit aligned. 801 if ((detectedSplat = rawBufferWidth == storageWidth)) 802 return true; 803 return rawBufferWidth == (storageWidth * type.getNumElements()); 804 } 805 806 /// Check the information for a C++ data type, check if this type is valid for 807 /// the current attribute. This method is used to verify specific type 808 /// invariants that the templatized 'getValues' method cannot. 809 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 810 bool isSigned) { 811 // Make sure that the data element size is the same as the type element width. 812 if (getDenseElementBitWidth(type) != 813 static_cast<size_t>(dataEltSize * CHAR_BIT)) 814 return false; 815 816 // Check that the element type is either float or integer or index. 817 if (!isInt) 818 return type.isa<FloatType>(); 819 if (type.isIndex()) 820 return true; 821 822 auto intType = type.dyn_cast<IntegerType>(); 823 if (!intType) 824 return false; 825 826 // Make sure signedness semantics is consistent. 827 if (intType.isSignless()) 828 return true; 829 return intType.isSigned() ? isSigned : !isSigned; 830 } 831 832 /// Defaults down the subclass implementation. 
833 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 834 ArrayRef<char> data, 835 int64_t dataEltSize, 836 bool isInt, bool isSigned) { 837 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 838 isSigned); 839 } 840 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 841 ArrayRef<char> data, 842 int64_t dataEltSize, 843 bool isInt, 844 bool isSigned) { 845 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 846 isInt, isSigned); 847 } 848 849 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 850 bool isSigned) const { 851 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 852 } 853 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 854 bool isSigned) const { 855 return ::isValidIntOrFloat( 856 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 857 isInt, isSigned); 858 } 859 860 /// Returns true if this attribute corresponds to a splat, i.e. if all element 861 /// values are the same. 862 bool DenseElementsAttr::isSplat() const { 863 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 864 } 865 866 /// Return if the given complex type has an integer element type. 
867 static bool isComplexOfIntType(Type type) { 868 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 869 } 870 871 auto DenseElementsAttr::getComplexIntValues() const 872 -> llvm::iterator_range<ComplexIntElementIterator> { 873 assert(isComplexOfIntType(getElementType()) && 874 "expected complex integral type"); 875 return {ComplexIntElementIterator(*this, 0), 876 ComplexIntElementIterator(*this, getNumElements())}; 877 } 878 auto DenseElementsAttr::complex_value_begin() const 879 -> ComplexIntElementIterator { 880 assert(isComplexOfIntType(getElementType()) && 881 "expected complex integral type"); 882 return ComplexIntElementIterator(*this, 0); 883 } 884 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 885 assert(isComplexOfIntType(getElementType()) && 886 "expected complex integral type"); 887 return ComplexIntElementIterator(*this, getNumElements()); 888 } 889 890 /// Return the held element values as a range of APFloat. The element type of 891 /// this attribute must be of float type. 
892 auto DenseElementsAttr::getFloatValues() const 893 -> llvm::iterator_range<FloatElementIterator> { 894 auto elementType = getElementType().cast<FloatType>(); 895 const auto &elementSemantics = elementType.getFloatSemantics(); 896 return {FloatElementIterator(elementSemantics, raw_int_begin()), 897 FloatElementIterator(elementSemantics, raw_int_end())}; 898 } 899 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 900 auto elementType = getElementType().cast<FloatType>(); 901 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 902 } 903 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 904 auto elementType = getElementType().cast<FloatType>(); 905 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 906 } 907 908 auto DenseElementsAttr::getComplexFloatValues() const 909 -> llvm::iterator_range<ComplexFloatElementIterator> { 910 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 911 assert(eltTy.isa<FloatType>() && "expected complex float type"); 912 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 913 return {{semantics, {*this, 0}}, 914 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 915 } 916 auto DenseElementsAttr::complex_float_value_begin() const 917 -> ComplexFloatElementIterator { 918 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 919 assert(eltTy.isa<FloatType>() && "expected complex float type"); 920 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 921 } 922 auto DenseElementsAttr::complex_float_value_end() const 923 -> ComplexFloatElementIterator { 924 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 925 assert(eltTy.isa<FloatType>() && "expected complex float type"); 926 return {eltTy.cast<FloatType>().getFloatSemantics(), 927 {*this, static_cast<size_t>(getNumElements())}}; 928 } 929 930 /// Return the raw storage data held by this attribute. 
931 ArrayRef<char> DenseElementsAttr::getRawData() const { 932 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 933 } 934 935 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 936 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 937 } 938 939 /// Return a new DenseElementsAttr that has the same data as the current 940 /// attribute, but has been reshaped to 'newType'. The new type must have the 941 /// same total number of elements as well as element type. 942 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 943 ShapedType curType = getType(); 944 if (curType == newType) 945 return *this; 946 947 assert(newType.getElementType() == curType.getElementType() && 948 "expected the same element type"); 949 assert(newType.getNumElements() == curType.getNumElements() && 950 "expected the same number of elements"); 951 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 952 } 953 954 /// Return a new DenseElementsAttr that has the same data as the current 955 /// attribute, but has bitcast elements such that it is now 'newType'. The new 956 /// type must have the same shape and element types of the same bitwidth as the 957 /// current type. 
958 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 959 ShapedType curType = getType(); 960 Type curElType = curType.getElementType(); 961 if (curElType == newElType) 962 return *this; 963 964 assert(getDenseElementBitWidth(newElType) == 965 getDenseElementBitWidth(curElType) && 966 "expected element types with the same bitwidth"); 967 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 968 getRawData(), isSplat()); 969 } 970 971 DenseElementsAttr 972 DenseElementsAttr::mapValues(Type newElementType, 973 function_ref<APInt(const APInt &)> mapping) const { 974 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 975 } 976 977 DenseElementsAttr DenseElementsAttr::mapValues( 978 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 979 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 980 } 981 982 ShapedType DenseElementsAttr::getType() const { 983 return Attribute::getType().cast<ShapedType>(); 984 } 985 986 Type DenseElementsAttr::getElementType() const { 987 return getType().getElementType(); 988 } 989 990 int64_t DenseElementsAttr::getNumElements() const { 991 return getType().getNumElements(); 992 } 993 994 //===----------------------------------------------------------------------===// 995 // DenseIntOrFPElementsAttr 996 //===----------------------------------------------------------------------===// 997 998 /// Utility method to write a range of APInt values to a buffer. 
999 template <typename APRangeT> 1000 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1001 APRangeT &&values) { 1002 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1003 size_t offset = 0; 1004 for (auto it = values.begin(), e = values.end(); it != e; 1005 ++it, offset += storageWidth) { 1006 assert((*it).getBitWidth() <= storageWidth); 1007 writeBits(data.data(), offset, *it); 1008 } 1009 } 1010 1011 /// Constructs a dense elements attribute from an array of raw APFloat values. 1012 /// Each APFloat value is expected to have the same bitwidth as the element 1013 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1014 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1015 size_t storageWidth, 1016 ArrayRef<APFloat> values, 1017 bool isSplat) { 1018 std::vector<char> data; 1019 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1020 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1021 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1022 } 1023 1024 /// Constructs a dense elements attribute from an array of raw APInt values. 1025 /// Each APInt value is expected to have the same bitwidth as the element type 1026 /// of 'type'. 
1027 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1028 size_t storageWidth, 1029 ArrayRef<APInt> values, 1030 bool isSplat) { 1031 std::vector<char> data; 1032 writeAPIntsToBuffer(storageWidth, data, values); 1033 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1034 } 1035 1036 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1037 ArrayRef<char> data, 1038 bool isSplat) { 1039 assert((type.isa<RankedTensorType, VectorType>()) && 1040 "type must be ranked tensor or vector"); 1041 assert(type.hasStaticShape() && "type must have static shape"); 1042 return Base::get(type.getContext(), type, data, isSplat); 1043 } 1044 1045 /// Overload of the raw 'get' method that asserts that the given type is of 1046 /// complex type. This method is used to verify type invariants that the 1047 /// templatized 'get' method cannot. 1048 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1049 ArrayRef<char> data, 1050 int64_t dataEltSize, 1051 bool isInt, 1052 bool isSigned) { 1053 assert(::isValidIntOrFloat( 1054 type.getElementType().cast<ComplexType>().getElementType(), 1055 dataEltSize / 2, isInt, isSigned)); 1056 1057 int64_t numElements = data.size() / dataEltSize; 1058 assert(numElements == 1 || numElements == type.getNumElements()); 1059 return getRaw(type, data, /*isSplat=*/numElements == 1); 1060 } 1061 1062 /// Overload of the 'getRaw' method that asserts that the given type is of 1063 /// integer type. This method is used to verify type invariants that the 1064 /// templatized 'get' method cannot. 
1065 DenseElementsAttr 1066 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1067 int64_t dataEltSize, bool isInt, 1068 bool isSigned) { 1069 assert( 1070 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1071 1072 int64_t numElements = data.size() / dataEltSize; 1073 assert(numElements == 1 || numElements == type.getNumElements()); 1074 return getRaw(type, data, /*isSplat=*/numElements == 1); 1075 } 1076 1077 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1078 const char *inRawData, char *outRawData, size_t elementBitWidth, 1079 size_t numElements) { 1080 using llvm::support::ulittle16_t; 1081 using llvm::support::ulittle32_t; 1082 using llvm::support::ulittle64_t; 1083 1084 assert(llvm::support::endian::system_endianness() == // NOLINT 1085 llvm::support::endianness::big); // NOLINT 1086 // NOLINT to avoid warning message about replacing by static_assert() 1087 1088 // Following std::copy_n always converts endianness on BE machine. 
1089 switch (elementBitWidth) { 1090 case 16: { 1091 const ulittle16_t *inRawDataPos = 1092 reinterpret_cast<const ulittle16_t *>(inRawData); 1093 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1094 std::copy_n(inRawDataPos, numElements, outDataPos); 1095 break; 1096 } 1097 case 32: { 1098 const ulittle32_t *inRawDataPos = 1099 reinterpret_cast<const ulittle32_t *>(inRawData); 1100 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1101 std::copy_n(inRawDataPos, numElements, outDataPos); 1102 break; 1103 } 1104 case 64: { 1105 const ulittle64_t *inRawDataPos = 1106 reinterpret_cast<const ulittle64_t *>(inRawData); 1107 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1108 std::copy_n(inRawDataPos, numElements, outDataPos); 1109 break; 1110 } 1111 default: { 1112 size_t nBytes = elementBitWidth / CHAR_BIT; 1113 for (size_t i = 0; i < nBytes; i++) 1114 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1115 break; 1116 } 1117 } 1118 } 1119 1120 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1121 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1122 ShapedType type) { 1123 size_t numElements = type.getNumElements(); 1124 Type elementType = type.getElementType(); 1125 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1126 elementType = complexTy.getElementType(); 1127 numElements = numElements * 2; 1128 } 1129 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1130 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1131 inRawData.size() <= outRawData.size()); 1132 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1133 elementBitWidth, numElements); 1134 } 1135 1136 //===----------------------------------------------------------------------===// 1137 // DenseFPElementsAttr 1138 //===----------------------------------------------------------------------===// 1139 1140 template <typename Fn, typename Attr> 
1141 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1142 Type newElementType, 1143 llvm::SmallVectorImpl<char> &data) { 1144 size_t bitWidth = getDenseElementBitWidth(newElementType); 1145 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1146 1147 ShapedType newArrayType; 1148 if (inType.isa<RankedTensorType>()) 1149 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1150 else if (inType.isa<UnrankedTensorType>()) 1151 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1152 else if (inType.isa<VectorType>()) 1153 newArrayType = VectorType::get(inType.getShape(), newElementType); 1154 else 1155 assert(newArrayType && "Unhandled tensor type"); 1156 1157 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1158 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1159 1160 // Functor used to process a single element value of the attribute. 1161 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1162 auto newInt = mapping(value); 1163 assert(newInt.getBitWidth() == bitWidth); 1164 writeBits(data.data(), index * storageBitWidth, newInt); 1165 }; 1166 1167 // Check for the splat case. 1168 if (attr.isSplat()) { 1169 processElt(*attr.begin(), /*index=*/0); 1170 return newArrayType; 1171 } 1172 1173 // Otherwise, process all of the element values. 1174 uint64_t elementIdx = 0; 1175 for (auto value : attr) 1176 processElt(value, elementIdx++); 1177 return newArrayType; 1178 } 1179 1180 DenseElementsAttr DenseFPElementsAttr::mapValues( 1181 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1182 llvm::SmallVector<char, 8> elementData; 1183 auto newArrayType = 1184 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1185 1186 return getRaw(newArrayType, elementData, isSplat()); 1187 } 1188 1189 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1190 bool DenseFPElementsAttr::classof(Attribute attr) { 1191 return attr.isa<DenseElementsAttr>() && 1192 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1193 } 1194 1195 //===----------------------------------------------------------------------===// 1196 // DenseIntElementsAttr 1197 //===----------------------------------------------------------------------===// 1198 1199 DenseElementsAttr DenseIntElementsAttr::mapValues( 1200 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1201 llvm::SmallVector<char, 8> elementData; 1202 auto newArrayType = 1203 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1204 1205 return getRaw(newArrayType, elementData, isSplat()); 1206 } 1207 1208 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1209 bool DenseIntElementsAttr::classof(Attribute attr) { 1210 return attr.isa<DenseElementsAttr>() && 1211 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1212 } 1213 1214 //===----------------------------------------------------------------------===// 1215 // OpaqueElementsAttr 1216 //===----------------------------------------------------------------------===// 1217 1218 /// Return the value at the given index. If index does not refer to a valid 1219 /// element, then a null attribute is returned. 
1220 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1221 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1222 return Attribute(); 1223 } 1224 1225 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1226 Dialect *dialect = getDialect().getDialect(); 1227 if (!dialect) 1228 return true; 1229 auto *interface = 1230 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1231 if (!interface) 1232 return true; 1233 return failed(interface->decode(*this, result)); 1234 } 1235 1236 LogicalResult 1237 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1238 Identifier dialect, StringRef value, 1239 ShapedType type) { 1240 if (!Dialect::isValidNamespace(dialect.strref())) 1241 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1242 return success(); 1243 } 1244 1245 //===----------------------------------------------------------------------===// 1246 // SparseElementsAttr 1247 //===----------------------------------------------------------------------===// 1248 1249 /// Return the value of the element at the given index. 1250 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1251 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1252 auto type = getType(); 1253 1254 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1255 // as a 1-D index array. 1256 auto sparseIndices = getIndices(); 1257 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1258 1259 // Check to see if the indices are a splat. 1260 if (sparseIndices.isSplat()) { 1261 // If the index is also not a splat of the index value, we know that the 1262 // value is zero. 1263 auto splatIndex = *sparseIndexValues.begin(); 1264 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1265 return getZeroAttr(); 1266 1267 // If the indices are a splat, we also expect the values to be a splat. 
1268 assert(getValues().isSplat() && "expected splat values"); 1269 return getValues().getSplatValue(); 1270 } 1271 1272 // Build a mapping between known indices and the offset of the stored element. 1273 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1274 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1275 size_t rank = type.getRank(); 1276 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1277 mappedIndices.try_emplace( 1278 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1279 1280 // Look for the provided index key within the mapped indices. If the provided 1281 // index is not found, then return a zero attribute. 1282 auto it = mappedIndices.find(index); 1283 if (it == mappedIndices.end()) 1284 return getZeroAttr(); 1285 1286 // Otherwise, return the held sparse value element. 1287 return getValues().getValue(it->second); 1288 } 1289 1290 /// Get a zero APFloat for the given sparse attribute. 1291 APFloat SparseElementsAttr::getZeroAPFloat() const { 1292 auto eltType = getElementType().cast<FloatType>(); 1293 return APFloat(eltType.getFloatSemantics()); 1294 } 1295 1296 /// Get a zero APInt for the given sparse attribute. 1297 APInt SparseElementsAttr::getZeroAPInt() const { 1298 auto eltType = getElementType().cast<IntegerType>(); 1299 return APInt::getZero(eltType.getWidth()); 1300 } 1301 1302 /// Get a zero attribute for the given attribute type. 1303 Attribute SparseElementsAttr::getZeroAttr() const { 1304 auto eltType = getElementType(); 1305 1306 // Handle floating point elements. 1307 if (eltType.isa<FloatType>()) 1308 return FloatAttr::get(eltType, 0); 1309 1310 // Otherwise, this is an integer. 1311 // TODO: Handle StringAttr here. 1312 return IntegerAttr::get(eltType, 0); 1313 } 1314 1315 /// Flatten, and return, all of the sparse indices in this attribute in 1316 /// row-major order. 
1317 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1318 std::vector<ptrdiff_t> flatSparseIndices; 1319 1320 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1321 // as a 1-D index array. 1322 auto sparseIndices = getIndices(); 1323 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1324 if (sparseIndices.isSplat()) { 1325 SmallVector<uint64_t, 8> indices(getType().getRank(), 1326 *sparseIndexValues.begin()); 1327 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1328 return flatSparseIndices; 1329 } 1330 1331 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1332 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1333 size_t rank = getType().getRank(); 1334 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1335 flatSparseIndices.push_back(getFlattenedIndex( 1336 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1337 return flatSparseIndices; 1338 } 1339 1340 LogicalResult 1341 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1342 ShapedType type, DenseIntElementsAttr sparseIndices, 1343 DenseElementsAttr values) { 1344 ShapedType valuesType = values.getType(); 1345 if (valuesType.getRank() != 1) 1346 return emitError() << "expected 1-d tensor for sparse element values"; 1347 1348 // Verify the indices and values shape. 1349 ShapedType indicesType = sparseIndices.getType(); 1350 auto emitShapeError = [&]() { 1351 return emitError() << "expected shape ([" << type.getShape() 1352 << "]); inferred shape of indices literal ([" 1353 << indicesType.getShape() 1354 << "]); inferred shape of values literal ([" 1355 << valuesType.getShape() << "])"; 1356 }; 1357 // Verify indices shape. 
1358 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1359 if (indicesRank == 2) { 1360 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1361 return emitShapeError(); 1362 } else if (indicesRank != 1 || rank != 1) { 1363 return emitShapeError(); 1364 } 1365 // Verify the values shape. 1366 int64_t numSparseIndices = indicesType.getDimSize(0); 1367 if (numSparseIndices != valuesType.getDimSize(0)) 1368 return emitShapeError(); 1369 1370 // Verify that the sparse indices are within the value shape. 1371 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1372 return emitError() 1373 << "sparse index #" << indexNum 1374 << " is not contained within the value shape, with index=[" << index 1375 << "], and type=" << type; 1376 }; 1377 1378 // Handle the case where the index values are a splat. 1379 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1380 if (sparseIndices.isSplat()) { 1381 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1382 if (!ElementsAttr::isValidIndex(type, indices)) 1383 return emitIndexError(0, indices); 1384 return success(); 1385 } 1386 1387 // Otherwise, reinterpret each index as an ArrayRef. 1388 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1389 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1390 rank); 1391 if (!ElementsAttr::isValidIndex(type, index)) 1392 return emitIndexError(i, index); 1393 } 1394 1395 return success(); 1396 } 1397 1398 //===----------------------------------------------------------------------===// 1399 // TypeAttr 1400 //===----------------------------------------------------------------------===// 1401 1402 void TypeAttr::walkImmediateSubElements( 1403 function_ref<void(Attribute)> walkAttrsFn, 1404 function_ref<void(Type)> walkTypesFn) const { 1405 walkTypesFn(getValue()); 1406 } 1407