1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/Twine.h" 20 #include "llvm/Support/Endian.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 //===----------------------------------------------------------------------===// 26 /// Tablegen Attribute Definitions 27 //===----------------------------------------------------------------------===// 28 29 #define GET_ATTRDEF_CLASSES 30 #include "mlir/IR/BuiltinAttributes.cpp.inc" 31 32 //===----------------------------------------------------------------------===// 33 // BuiltinDialect 34 //===----------------------------------------------------------------------===// 35 36 void BuiltinDialect::registerAttributes() { 37 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 38 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 39 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 40 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 41 UnitAttr>(); 42 } 43 44 //===----------------------------------------------------------------------===// 45 // DictionaryAttr 46 //===----------------------------------------------------------------------===// 47 48 /// Helper function that does either an in place sort or sorts from source array 49 /// into destination. 
If inPlace then storage is both the source and the 50 /// destination, else value is the source and storage destination. Returns 51 /// whether source was sorted. 52 template <bool inPlace> 53 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 54 SmallVectorImpl<NamedAttribute> &storage) { 55 // Specialize for the common case. 56 switch (value.size()) { 57 case 0: 58 // Zero already sorted. 59 break; 60 case 1: 61 // One already sorted but may need to be copied. 62 if (!inPlace) 63 storage.assign({value[0]}); 64 break; 65 case 2: { 66 bool isSorted = value[0] < value[1]; 67 if (inPlace) { 68 if (!isSorted) 69 std::swap(storage[0], storage[1]); 70 } else if (isSorted) { 71 storage.assign({value[0], value[1]}); 72 } else { 73 storage.assign({value[1], value[0]}); 74 } 75 return !isSorted; 76 } 77 default: 78 if (!inPlace) 79 storage.assign(value.begin(), value.end()); 80 // Check to see they are sorted already. 81 bool isSorted = llvm::is_sorted(value); 82 if (!isSorted) { 83 // If not, do a general sort. 84 llvm::array_pod_sort(storage.begin(), storage.end()); 85 value = storage; 86 } 87 return !isSorted; 88 } 89 return false; 90 } 91 92 /// Returns an entry with a duplicate name from the given sorted array of named 93 /// attributes. Returns llvm::None if all elements have unique names. 94 static Optional<NamedAttribute> 95 findDuplicateElement(ArrayRef<NamedAttribute> value) { 96 const Optional<NamedAttribute> none{llvm::None}; 97 if (value.size() < 2) 98 return none; 99 100 if (value.size() == 2) 101 return value[0].first == value[1].first ? value[0] : none; 102 103 auto it = std::adjacent_find( 104 value.begin(), value.end(), 105 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 106 return it != value.end() ? 
*it : none; 107 } 108 109 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 110 SmallVectorImpl<NamedAttribute> &storage) { 111 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 112 assert(!findDuplicateElement(storage) && 113 "DictionaryAttr element names must be unique"); 114 return isSorted; 115 } 116 117 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 118 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 119 assert(!findDuplicateElement(array) && 120 "DictionaryAttr element names must be unique"); 121 return isSorted; 122 } 123 124 Optional<NamedAttribute> 125 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 126 bool isSorted) { 127 if (!isSorted) 128 dictionaryAttrSort</*inPlace=*/true>(array, array); 129 return findDuplicateElement(array); 130 } 131 132 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 133 ArrayRef<NamedAttribute> value) { 134 if (value.empty()) 135 return DictionaryAttr::getEmpty(context); 136 assert(llvm::all_of(value, 137 [](const NamedAttribute &attr) { return attr.second; }) && 138 "value cannot have null entries"); 139 140 // We need to sort the element list to canonicalize it. 141 SmallVector<NamedAttribute, 8> storage; 142 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 143 value = storage; 144 assert(!findDuplicateElement(value) && 145 "DictionaryAttr element names must be unique"); 146 return Base::get(context, value); 147 } 148 /// Construct a dictionary with an array of values that is known to already be 149 /// sorted by name and uniqued. 150 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 151 ArrayRef<NamedAttribute> value) { 152 if (value.empty()) 153 return DictionaryAttr::getEmpty(context); 154 // Ensure that the attribute elements are unique and sorted. 
155 assert(llvm::is_sorted(value, 156 [](NamedAttribute l, NamedAttribute r) { 157 return l.first.strref() < r.first.strref(); 158 }) && 159 "expected attribute values to be sorted"); 160 assert(!findDuplicateElement(value) && 161 "DictionaryAttr element names must be unique"); 162 return Base::get(context, value); 163 } 164 165 /// Return the specified attribute if present, null otherwise. 166 Attribute DictionaryAttr::get(StringRef name) const { 167 Optional<NamedAttribute> attr = getNamed(name); 168 return attr ? attr->second : nullptr; 169 } 170 Attribute DictionaryAttr::get(Identifier name) const { 171 Optional<NamedAttribute> attr = getNamed(name); 172 return attr ? attr->second : nullptr; 173 } 174 175 /// Return the specified named attribute if present, None otherwise. 176 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 177 ArrayRef<NamedAttribute> values = getValue(); 178 const auto *it = llvm::lower_bound(values, name); 179 return it != values.end() && it->first == name ? 
*it 180 : Optional<NamedAttribute>(); 181 } 182 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 183 for (auto elt : getValue()) 184 if (elt.first == name) 185 return elt; 186 return llvm::None; 187 } 188 189 DictionaryAttr::iterator DictionaryAttr::begin() const { 190 return getValue().begin(); 191 } 192 DictionaryAttr::iterator DictionaryAttr::end() const { 193 return getValue().end(); 194 } 195 size_t DictionaryAttr::size() const { return getValue().size(); } 196 197 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 198 return Base::get(context, ArrayRef<NamedAttribute>()); 199 } 200 201 //===----------------------------------------------------------------------===// 202 // FloatAttr 203 //===----------------------------------------------------------------------===// 204 205 double FloatAttr::getValueAsDouble() const { 206 return getValueAsDouble(getValue()); 207 } 208 double FloatAttr::getValueAsDouble(APFloat value) { 209 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 210 bool losesInfo = false; 211 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 212 &losesInfo); 213 } 214 return value.convertToDouble(); 215 } 216 217 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 218 Type type, APFloat value) { 219 // Verify that the type is correct. 220 if (!type.isa<FloatType>()) 221 return emitError() << "expected floating point type"; 222 223 // Verify that the type semantics match that of the value. 
224 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 225 return emitError() 226 << "FloatAttr type doesn't match the type implied by its value"; 227 } 228 return success(); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // SymbolRefAttr 233 //===----------------------------------------------------------------------===// 234 235 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 236 return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 237 } 238 239 StringRef SymbolRefAttr::getLeafReference() const { 240 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 241 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // IntegerAttr 246 //===----------------------------------------------------------------------===// 247 248 int64_t IntegerAttr::getInt() const { 249 assert((getType().isIndex() || getType().isSignlessInteger()) && 250 "must be signless integer"); 251 return getValue().getSExtValue(); 252 } 253 254 int64_t IntegerAttr::getSInt() const { 255 assert(getType().isSignedInteger() && "must be signed integer"); 256 return getValue().getSExtValue(); 257 } 258 259 uint64_t IntegerAttr::getUInt() const { 260 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 261 return getValue().getZExtValue(); 262 } 263 264 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 265 Type type, APInt value) { 266 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 267 if (integerType.getWidth() != value.getBitWidth()) 268 return emitError() << "integer type bit width (" << integerType.getWidth() 269 << ") doesn't match value bit width (" 270 << value.getBitWidth() << ")"; 271 return success(); 272 } 273 if (type.isa<IndexType>()) 274 return success(); 275 return emitError() << "expected 
integer or index type"; 276 } 277 278 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 279 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 280 return attr.cast<BoolAttr>(); 281 } 282 283 //===----------------------------------------------------------------------===// 284 // BoolAttr 285 286 bool BoolAttr::getValue() const { 287 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 288 return storage->value.getBoolValue(); 289 } 290 291 bool BoolAttr::classof(Attribute attr) { 292 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 293 return intAttr && intAttr.getType().isSignlessInteger(1); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // OpaqueAttr 298 //===----------------------------------------------------------------------===// 299 300 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 301 Identifier dialect, StringRef attrData, 302 Type type) { 303 if (!Dialect::isValidNamespace(dialect.strref())) 304 return emitError() << "invalid dialect namespace '" << dialect << "'"; 305 return success(); 306 } 307 308 //===----------------------------------------------------------------------===// 309 // ElementsAttr 310 //===----------------------------------------------------------------------===// 311 312 ShapedType ElementsAttr::getType() const { 313 return Attribute::getType().cast<ShapedType>(); 314 } 315 316 /// Returns the number of elements held by this attribute. 317 int64_t ElementsAttr::getNumElements() const { 318 return getType().getNumElements(); 319 } 320 321 /// Return the value at the given index. If index does not refer to a valid 322 /// element, then a null attribute is returned. 
323 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 324 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 325 return denseAttr.getValue(index); 326 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 327 return opaqueAttr.getValue(index); 328 return cast<SparseElementsAttr>().getValue(index); 329 } 330 331 /// Return if the given 'index' refers to a valid element in this attribute. 332 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 333 auto type = getType(); 334 335 // Verify that the rank of the indices matches the held type. 336 auto rank = type.getRank(); 337 if (rank == 0 && index.size() == 1 && index[0] == 0) 338 return true; 339 if (rank != static_cast<int64_t>(index.size())) 340 return false; 341 342 // Verify that all of the indices are within the shape dimensions. 343 auto shape = type.getShape(); 344 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 345 return static_cast<int64_t>(index[i]) < shape[i]; 346 }); 347 } 348 349 ElementsAttr 350 ElementsAttr::mapValues(Type newElementType, 351 function_ref<APInt(const APInt &)> mapping) const { 352 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 353 return intOrFpAttr.mapValues(newElementType, mapping); 354 llvm_unreachable("unsupported ElementsAttr subtype"); 355 } 356 357 ElementsAttr 358 ElementsAttr::mapValues(Type newElementType, 359 function_ref<APInt(const APFloat &)> mapping) const { 360 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 361 return intOrFpAttr.mapValues(newElementType, mapping); 362 llvm_unreachable("unsupported ElementsAttr subtype"); 363 } 364 365 /// Method for support type inquiry through isa, cast and dyn_cast. 366 bool ElementsAttr::classof(Attribute attr) { 367 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 368 OpaqueElementsAttr, SparseElementsAttr>(); 369 } 370 371 /// Returns the 1 dimensional flattened row-major index from the given 372 /// multi-dimensional index. 
373 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 374 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 375 auto type = getType(); 376 377 // Reduce the provided multidimensional index into a flattended 1D row-major 378 // index. 379 auto rank = type.getRank(); 380 auto shape = type.getShape(); 381 uint64_t valueIndex = 0; 382 uint64_t dimMultiplier = 1; 383 for (int i = rank - 1; i >= 0; --i) { 384 valueIndex += index[i] * dimMultiplier; 385 dimMultiplier *= shape[i]; 386 } 387 return valueIndex; 388 } 389 390 //===----------------------------------------------------------------------===// 391 // DenseElementsAttr Utilities 392 //===----------------------------------------------------------------------===// 393 394 /// Get the bitwidth of a dense element type within the buffer. 395 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 396 static size_t getDenseElementStorageWidth(size_t origWidth) { 397 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 398 } 399 static size_t getDenseElementStorageWidth(Type elementType) { 400 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 401 } 402 403 /// Set a bit to a specific value. 404 static void setBit(char *rawData, size_t bitPos, bool value) { 405 if (value) 406 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 407 else 408 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 409 } 410 411 /// Return the value of the specified bit. 412 static bool getBit(const char *rawData, size_t bitPos) { 413 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 414 } 415 416 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 417 /// BE format. 
/// Precondition (asserted): only valid on a big-endian host, and `value` must
/// have at least `numBytes` bytes of word storage.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // NOTE: despite its name, `numFilledWords` is a byte count (all words but
  // the last, times the word size in bytes).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`. Only the remaining (numBytes - lastWordPos)
  // bytes are written.
  // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}

/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format. This is the inverse of copyAPIntToArrayForBEmachine. Precondition
/// (asserted): big-endian host, and `result` has at least `numBytes` bytes of
/// word storage.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // NOTE: `numFilledWords` is a byte count, as in the function above. The
  // const_cast writes into the APInt's word storage exposed by getRawData().
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}

/// Writes value to the bit position `bitPos` in array `rawData`.
/// A 1-bit value writes a single bit; wider values require `bitPos` to be
/// byte aligned and write divideCeil(bitWidth, 8) bytes.
static void writeBits(char *rawData, size_t bitPos, APInt value) {
  size_t bitWidth = value.getBitWidth();

  // If the bitwidth is 1 we just set (or clear) the specific bit.
  if (bitWidth == 1)
    return setBit(rawData, bitPos, value.isOneValue());

  // Otherwise, the bit position is guaranteed to be byte aligned.
  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big) {
    // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
    // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
    // work correctly in BE format.
    // ex. `value` (2 words including 10 bytes)
    // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
    copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
                                 rawData + (bitPos / CHAR_BIT));
  } else {
    // Little-endian host: the low bytes of the APInt's raw words are already
    // in buffer order, so copy them directly.
    std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
                llvm::divideCeil(bitWidth, CHAR_BIT),
                rawData + (bitPos / CHAR_BIT));
  }
}

/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData`. A width of 1 reads a single bit; wider reads require `bitPos`
/// to be byte aligned.
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
  // Handle a boolean bit position.
  if (bitWidth == 1)
    return APInt(1, getBit(rawData, bitPos) ? 1 : 0);

  // Otherwise, the bit position must be 8-bit aligned.
  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
  APInt result(bitWidth, 0);
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big) {
    // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
    // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
    // work correctly in BE format.
    // ex. `result` (2 words including 10 bytes)
    // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
    // copyArrayToAPIntForBEmachine performs the required byte reordering.
    copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
                                 llvm::divideCeil(bitWidth, CHAR_BIT), result);
  } else {
    // Little-endian host: copy the bytes straight into the APInt's word
    // storage (written through a const_cast of getRawData()).
    std::copy_n(rawData + (bitPos / CHAR_BIT),
                llvm::divideCeil(bitWidth, CHAR_BIT),
                const_cast<char *>(
                    reinterpret_cast<const char *>(result.getRawData())));
  }
  return result;
}

/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
/// the same element count as 'type'.
538 template <typename Values> 539 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 540 return (values.size() == 1) || 541 (type.getNumElements() == static_cast<int64_t>(values.size())); 542 } 543 544 //===----------------------------------------------------------------------===// 545 // DenseElementsAttr Iterators 546 //===----------------------------------------------------------------------===// 547 548 //===----------------------------------------------------------------------===// 549 // AttributeElementIterator 550 551 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 552 DenseElementsAttr attr, size_t index) 553 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 554 Attribute, Attribute, Attribute>( 555 attr.getAsOpaquePointer(), index) {} 556 557 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 558 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 559 Type eltTy = owner.getType().getElementType(); 560 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 561 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 562 if (eltTy.isa<IndexType>()) 563 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 564 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 565 IntElementIterator intIt(owner, index); 566 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 567 return FloatAttr::get(eltTy, *floatIt); 568 } 569 if (owner.isa<DenseStringElementsAttr>()) { 570 ArrayRef<StringRef> vals = owner.getRawStringData(); 571 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 572 } 573 llvm_unreachable("unexpected element type"); 574 } 575 576 //===----------------------------------------------------------------------===// 577 // BoolElementIterator 578 579 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 580 DenseElementsAttr attr, size_t dataIndex) 581 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 582 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 583 584 bool DenseElementsAttr::BoolElementIterator::operator*() const { 585 return getBit(getData(), getDataIndex()); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // IntElementIterator 590 591 DenseElementsAttr::IntElementIterator::IntElementIterator( 592 DenseElementsAttr attr, size_t dataIndex) 593 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 594 attr.getRawData().data(), attr.isSplat(), dataIndex), 595 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 596 597 APInt DenseElementsAttr::IntElementIterator::operator*() const { 598 return readBits(getData(), 599 getDataIndex() * getDenseElementStorageWidth(bitWidth), 600 bitWidth); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // ComplexIntElementIterator 605 606 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 607 DenseElementsAttr attr, size_t dataIndex) 608 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 609 std::complex<APInt>, std::complex<APInt>, 610 std::complex<APInt>>( 611 attr.getRawData().data(), attr.isSplat(), dataIndex) { 612 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 613 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 614 } 615 616 std::complex<APInt> 617 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 618 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 619 size_t offset = 
getDataIndex() * storageWidth * 2; 620 return {readBits(getData(), offset, bitWidth), 621 readBits(getData(), offset + storageWidth, bitWidth)}; 622 } 623 624 //===----------------------------------------------------------------------===// 625 // FloatElementIterator 626 627 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 628 const llvm::fltSemantics &smt, IntElementIterator it) 629 : llvm::mapped_iterator<IntElementIterator, 630 std::function<APFloat(const APInt &)>>( 631 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 632 633 //===----------------------------------------------------------------------===// 634 // ComplexFloatElementIterator 635 636 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 637 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 638 : llvm::mapped_iterator< 639 ComplexIntElementIterator, 640 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 641 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 642 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 643 }) {} 644 645 //===----------------------------------------------------------------------===// 646 // DenseElementsAttr 647 //===----------------------------------------------------------------------===// 648 649 /// Method for support type inquiry through isa, cast and dyn_cast. 650 bool DenseElementsAttr::classof(Attribute attr) { 651 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 652 } 653 654 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 655 ArrayRef<Attribute> values) { 656 assert(hasSameElementsOrSplat(type, values)); 657 658 // If the element type is not based on int/float/index, assume it is a string 659 // type. 
660 auto eltType = type.getElementType(); 661 if (!type.getElementType().isIntOrIndexOrFloat()) { 662 SmallVector<StringRef, 8> stringValues; 663 stringValues.reserve(values.size()); 664 for (Attribute attr : values) { 665 assert(attr.isa<StringAttr>() && 666 "expected string value for non integer/index/float element"); 667 stringValues.push_back(attr.cast<StringAttr>().getValue()); 668 } 669 return get(type, stringValues); 670 } 671 672 // Otherwise, get the raw storage width to use for the allocation. 673 size_t bitWidth = getDenseElementBitWidth(eltType); 674 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 675 676 // Compress the attribute values into a character buffer. 677 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 678 values.size()); 679 APInt intVal; 680 for (unsigned i = 0, e = values.size(); i < e; ++i) { 681 assert(eltType == values[i].getType() && 682 "expected attribute value to have element type"); 683 if (eltType.isa<FloatType>()) 684 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 685 else if (eltType.isa<IntegerType>()) 686 intVal = values[i].cast<IntegerAttr>().getValue(); 687 else 688 llvm_unreachable("unexpected element type"); 689 690 assert(intVal.getBitWidth() == bitWidth && 691 "expected value to have same bitwidth as element type"); 692 writeBits(data.data(), i * storageBitWidth, intVal); 693 } 694 return DenseIntOrFPElementsAttr::getRaw(type, data, 695 /*isSplat=*/(values.size() == 1)); 696 } 697 698 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 699 ArrayRef<bool> values) { 700 assert(hasSameElementsOrSplat(type, values)); 701 assert(type.getElementType().isInteger(1)); 702 703 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 704 for (int i = 0, e = values.size(); i != e; ++i) 705 setBit(buff.data(), i, values[i]); 706 return DenseIntOrFPElementsAttr::getRaw(type, buff, 707 /*isSplat=*/(values.size() == 1)); 708 } 709 710 DenseElementsAttr 
DenseElementsAttr::get(ShapedType type, 711 ArrayRef<StringRef> values) { 712 assert(!type.getElementType().isIntOrFloat()); 713 return DenseStringElementsAttr::get(type, values); 714 } 715 716 /// Constructs a dense integer elements attribute from an array of APInt 717 /// values. Each APInt value is expected to have the same bitwidth as the 718 /// element type of 'type'. 719 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 720 ArrayRef<APInt> values) { 721 assert(type.getElementType().isIntOrIndex()); 722 assert(hasSameElementsOrSplat(type, values)); 723 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 724 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 725 /*isSplat=*/(values.size() == 1)); 726 } 727 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 728 ArrayRef<std::complex<APInt>> values) { 729 ComplexType complex = type.getElementType().cast<ComplexType>(); 730 assert(complex.getElementType().isa<IntegerType>()); 731 assert(hasSameElementsOrSplat(type, values)); 732 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 733 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 734 values.size() * 2); 735 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 736 /*isSplat=*/(values.size() == 1)); 737 } 738 739 // Constructs a dense float elements attribute from an array of APFloat 740 // values. Each APFloat value is expected to have the same bitwidth as the 741 // element type of 'type'. 
742 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 743 ArrayRef<APFloat> values) { 744 assert(type.getElementType().isa<FloatType>()); 745 assert(hasSameElementsOrSplat(type, values)); 746 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 747 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 748 /*isSplat=*/(values.size() == 1)); 749 } 750 DenseElementsAttr 751 DenseElementsAttr::get(ShapedType type, 752 ArrayRef<std::complex<APFloat>> values) { 753 ComplexType complex = type.getElementType().cast<ComplexType>(); 754 assert(complex.getElementType().isa<FloatType>()); 755 assert(hasSameElementsOrSplat(type, values)); 756 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 757 values.size() * 2); 758 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 759 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 760 /*isSplat=*/(values.size() == 1)); 761 } 762 763 /// Construct a dense elements attribute from a raw buffer representing the 764 /// data for this attribute. Users should generally not use this methods as 765 /// the expected buffer format may not be a form the user expects. 766 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 767 ArrayRef<char> rawBuffer, 768 bool isSplatBuffer) { 769 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 770 } 771 772 /// Returns true if the given buffer is a valid raw buffer for the given type. 773 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 774 ArrayRef<char> rawBuffer, 775 bool &detectedSplat) { 776 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 777 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 778 779 // Storage width of 1 is special as it is packed by the bit. 780 if (storageWidth == 1) { 781 // Check for a splat, or a buffer equal to the number of elements. 
782 if ((detectedSplat = rawBuffer.size() == 1)) 783 return true; 784 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 785 } 786 // All other types are 8-bit aligned. 787 if ((detectedSplat = rawBufferWidth == storageWidth)) 788 return true; 789 return rawBufferWidth == (storageWidth * type.getNumElements()); 790 } 791 792 /// Check the information for a C++ data type, check if this type is valid for 793 /// the current attribute. This method is used to verify specific type 794 /// invariants that the templatized 'getValues' method cannot. 795 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 796 bool isSigned) { 797 // Make sure that the data element size is the same as the type element width. 798 if (getDenseElementBitWidth(type) != 799 static_cast<size_t>(dataEltSize * CHAR_BIT)) 800 return false; 801 802 // Check that the element type is either float or integer or index. 803 if (!isInt) 804 return type.isa<FloatType>(); 805 if (type.isIndex()) 806 return true; 807 808 auto intType = type.dyn_cast<IntegerType>(); 809 if (!intType) 810 return false; 811 812 // Make sure signedness semantics is consistent. 813 if (intType.isSignless()) 814 return true; 815 return intType.isSigned() ? isSigned : !isSigned; 816 } 817 818 /// Defaults down the subclass implementation. 819 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 820 ArrayRef<char> data, 821 int64_t dataEltSize, 822 bool isInt, bool isSigned) { 823 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 824 isSigned); 825 } 826 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 827 ArrayRef<char> data, 828 int64_t dataEltSize, 829 bool isInt, 830 bool isSigned) { 831 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 832 isInt, isSigned); 833 } 834 835 /// A method used to verify specific type invariants that the templatized 'get' 836 /// method cannot. 
837 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 838 bool isSigned) const { 839 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 840 isSigned); 841 } 842 843 /// Check the information for a C++ data type, check if this type is valid for 844 /// the current attribute. 845 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 846 bool isSigned) const { 847 return ::isValidIntOrFloat( 848 getType().getElementType().cast<ComplexType>().getElementType(), 849 dataEltSize / 2, isInt, isSigned); 850 } 851 852 /// Returns true if this attribute corresponds to a splat, i.e. if all element 853 /// values are the same. 854 bool DenseElementsAttr::isSplat() const { 855 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 856 } 857 858 /// Return the held element values as a range of Attributes. 859 auto DenseElementsAttr::getAttributeValues() const 860 -> llvm::iterator_range<AttributeElementIterator> { 861 return {attr_value_begin(), attr_value_end()}; 862 } 863 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 864 return AttributeElementIterator(*this, 0); 865 } 866 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 867 return AttributeElementIterator(*this, getNumElements()); 868 } 869 870 /// Return the held element values as a range of bool. The element type of 871 /// this attribute must be of integer type of bitwidth 1. 872 auto DenseElementsAttr::getBoolValues() const 873 -> llvm::iterator_range<BoolElementIterator> { 874 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 875 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 876 (void)eltType; 877 return {BoolElementIterator(*this, 0), 878 BoolElementIterator(*this, getNumElements())}; 879 } 880 881 /// Return the held element values as a range of APInts. The element type of 882 /// this attribute must be of integer type. 
883 auto DenseElementsAttr::getIntValues() const 884 -> llvm::iterator_range<IntElementIterator> { 885 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 886 return {raw_int_begin(), raw_int_end()}; 887 } 888 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 889 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 890 return raw_int_begin(); 891 } 892 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 893 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 894 return raw_int_end(); 895 } 896 auto DenseElementsAttr::getComplexIntValues() const 897 -> llvm::iterator_range<ComplexIntElementIterator> { 898 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 899 (void)eltTy; 900 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 901 return {ComplexIntElementIterator(*this, 0), 902 ComplexIntElementIterator(*this, getNumElements())}; 903 } 904 905 /// Return the held element values as a range of APFloat. The element type of 906 /// this attribute must be of float type. 
907 auto DenseElementsAttr::getFloatValues() const 908 -> llvm::iterator_range<FloatElementIterator> { 909 auto elementType = getType().getElementType().cast<FloatType>(); 910 const auto &elementSemantics = elementType.getFloatSemantics(); 911 return {FloatElementIterator(elementSemantics, raw_int_begin()), 912 FloatElementIterator(elementSemantics, raw_int_end())}; 913 } 914 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 915 return getFloatValues().begin(); 916 } 917 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 918 return getFloatValues().end(); 919 } 920 auto DenseElementsAttr::getComplexFloatValues() const 921 -> llvm::iterator_range<ComplexFloatElementIterator> { 922 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 923 assert(eltTy.isa<FloatType>() && "expected complex float type"); 924 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 925 return {{semantics, {*this, 0}}, 926 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 927 } 928 929 /// Return the raw storage data held by this attribute. 930 ArrayRef<char> DenseElementsAttr::getRawData() const { 931 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 932 } 933 934 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 935 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 936 } 937 938 /// Return a new DenseElementsAttr that has the same data as the current 939 /// attribute, but has been reshaped to 'newType'. The new type must have the 940 /// same total number of elements as well as element type. 
941 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 942 ShapedType curType = getType(); 943 if (curType == newType) 944 return *this; 945 946 (void)curType; 947 assert(newType.getElementType() == curType.getElementType() && 948 "expected the same element type"); 949 assert(newType.getNumElements() == curType.getNumElements() && 950 "expected the same number of elements"); 951 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 952 } 953 954 DenseElementsAttr 955 DenseElementsAttr::mapValues(Type newElementType, 956 function_ref<APInt(const APInt &)> mapping) const { 957 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 958 } 959 960 DenseElementsAttr DenseElementsAttr::mapValues( 961 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 962 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 963 } 964 965 //===----------------------------------------------------------------------===// 966 // DenseIntOrFPElementsAttr 967 //===----------------------------------------------------------------------===// 968 969 /// Utility method to write a range of APInt values to a buffer. 970 template <typename APRangeT> 971 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 972 APRangeT &&values) { 973 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 974 size_t offset = 0; 975 for (auto it = values.begin(), e = values.end(); it != e; 976 ++it, offset += storageWidth) { 977 assert((*it).getBitWidth() <= storageWidth); 978 writeBits(data.data(), offset, *it); 979 } 980 } 981 982 /// Constructs a dense elements attribute from an array of raw APFloat values. 983 /// Each APFloat value is expected to have the same bitwidth as the element 984 /// type of 'type'. 'type' must be a vector or tensor with static shape. 
985 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 986 size_t storageWidth, 987 ArrayRef<APFloat> values, 988 bool isSplat) { 989 std::vector<char> data; 990 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 991 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 992 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 993 } 994 995 /// Constructs a dense elements attribute from an array of raw APInt values. 996 /// Each APInt value is expected to have the same bitwidth as the element type 997 /// of 'type'. 998 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 999 size_t storageWidth, 1000 ArrayRef<APInt> values, 1001 bool isSplat) { 1002 std::vector<char> data; 1003 writeAPIntsToBuffer(storageWidth, data, values); 1004 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1005 } 1006 1007 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1008 ArrayRef<char> data, 1009 bool isSplat) { 1010 assert((type.isa<RankedTensorType, VectorType>()) && 1011 "type must be ranked tensor or vector"); 1012 assert(type.hasStaticShape() && "type must have static shape"); 1013 return Base::get(type.getContext(), type, data, isSplat); 1014 } 1015 1016 /// Overload of the raw 'get' method that asserts that the given type is of 1017 /// complex type. This method is used to verify type invariants that the 1018 /// templatized 'get' method cannot. 
1019 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1020 ArrayRef<char> data, 1021 int64_t dataEltSize, 1022 bool isInt, 1023 bool isSigned) { 1024 assert(::isValidIntOrFloat( 1025 type.getElementType().cast<ComplexType>().getElementType(), 1026 dataEltSize / 2, isInt, isSigned)); 1027 1028 int64_t numElements = data.size() / dataEltSize; 1029 assert(numElements == 1 || numElements == type.getNumElements()); 1030 return getRaw(type, data, /*isSplat=*/numElements == 1); 1031 } 1032 1033 /// Overload of the 'getRaw' method that asserts that the given type is of 1034 /// integer type. This method is used to verify type invariants that the 1035 /// templatized 'get' method cannot. 1036 DenseElementsAttr 1037 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1038 int64_t dataEltSize, bool isInt, 1039 bool isSigned) { 1040 assert( 1041 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1042 1043 int64_t numElements = data.size() / dataEltSize; 1044 assert(numElements == 1 || numElements == type.getNumElements()); 1045 return getRaw(type, data, /*isSplat=*/numElements == 1); 1046 } 1047 1048 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1049 const char *inRawData, char *outRawData, size_t elementBitWidth, 1050 size_t numElements) { 1051 using llvm::support::ulittle16_t; 1052 using llvm::support::ulittle32_t; 1053 using llvm::support::ulittle64_t; 1054 1055 assert(llvm::support::endian::system_endianness() == // NOLINT 1056 llvm::support::endianness::big); // NOLINT 1057 // NOLINT to avoid warning message about replacing by static_assert() 1058 1059 // Following std::copy_n always converts endianness on BE machine. 
1060 switch (elementBitWidth) { 1061 case 16: { 1062 const ulittle16_t *inRawDataPos = 1063 reinterpret_cast<const ulittle16_t *>(inRawData); 1064 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1065 std::copy_n(inRawDataPos, numElements, outDataPos); 1066 break; 1067 } 1068 case 32: { 1069 const ulittle32_t *inRawDataPos = 1070 reinterpret_cast<const ulittle32_t *>(inRawData); 1071 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1072 std::copy_n(inRawDataPos, numElements, outDataPos); 1073 break; 1074 } 1075 case 64: { 1076 const ulittle64_t *inRawDataPos = 1077 reinterpret_cast<const ulittle64_t *>(inRawData); 1078 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1079 std::copy_n(inRawDataPos, numElements, outDataPos); 1080 break; 1081 } 1082 default: { 1083 size_t nBytes = elementBitWidth / CHAR_BIT; 1084 for (size_t i = 0; i < nBytes; i++) 1085 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1086 break; 1087 } 1088 } 1089 } 1090 1091 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1092 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1093 ShapedType type) { 1094 size_t numElements = type.getNumElements(); 1095 Type elementType = type.getElementType(); 1096 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1097 elementType = complexTy.getElementType(); 1098 numElements = numElements * 2; 1099 } 1100 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1101 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1102 inRawData.size() <= outRawData.size()); 1103 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1104 elementBitWidth, numElements); 1105 } 1106 1107 //===----------------------------------------------------------------------===// 1108 // DenseFPElementsAttr 1109 //===----------------------------------------------------------------------===// 1110 1111 template <typename Fn, typename Attr> 
1112 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1113 Type newElementType, 1114 llvm::SmallVectorImpl<char> &data) { 1115 size_t bitWidth = getDenseElementBitWidth(newElementType); 1116 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1117 1118 ShapedType newArrayType; 1119 if (inType.isa<RankedTensorType>()) 1120 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1121 else if (inType.isa<UnrankedTensorType>()) 1122 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1123 else if (inType.isa<VectorType>()) 1124 newArrayType = VectorType::get(inType.getShape(), newElementType); 1125 else 1126 assert(newArrayType && "Unhandled tensor type"); 1127 1128 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1129 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1130 1131 // Functor used to process a single element value of the attribute. 1132 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1133 auto newInt = mapping(value); 1134 assert(newInt.getBitWidth() == bitWidth); 1135 writeBits(data.data(), index * storageBitWidth, newInt); 1136 }; 1137 1138 // Check for the splat case. 1139 if (attr.isSplat()) { 1140 processElt(*attr.begin(), /*index=*/0); 1141 return newArrayType; 1142 } 1143 1144 // Otherwise, process all of the element values. 1145 uint64_t elementIdx = 0; 1146 for (auto value : attr) 1147 processElt(value, elementIdx++); 1148 return newArrayType; 1149 } 1150 1151 DenseElementsAttr DenseFPElementsAttr::mapValues( 1152 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1153 llvm::SmallVector<char, 8> elementData; 1154 auto newArrayType = 1155 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1156 1157 return getRaw(newArrayType, elementData, isSplat()); 1158 } 1159 1160 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1161 bool DenseFPElementsAttr::classof(Attribute attr) { 1162 return attr.isa<DenseElementsAttr>() && 1163 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1164 } 1165 1166 //===----------------------------------------------------------------------===// 1167 // DenseIntElementsAttr 1168 //===----------------------------------------------------------------------===// 1169 1170 DenseElementsAttr DenseIntElementsAttr::mapValues( 1171 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1172 llvm::SmallVector<char, 8> elementData; 1173 auto newArrayType = 1174 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1175 1176 return getRaw(newArrayType, elementData, isSplat()); 1177 } 1178 1179 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1180 bool DenseIntElementsAttr::classof(Attribute attr) { 1181 return attr.isa<DenseElementsAttr>() && 1182 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1183 } 1184 1185 //===----------------------------------------------------------------------===// 1186 // OpaqueElementsAttr 1187 //===----------------------------------------------------------------------===// 1188 1189 /// Return the value at the given index. If index does not refer to a valid 1190 /// element, then a null attribute is returned. 
1191 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1192 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1193 return Attribute(); 1194 } 1195 1196 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1197 Dialect *dialect = getDialect().getDialect(); 1198 if (!dialect) 1199 return true; 1200 auto *interface = 1201 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1202 if (!interface) 1203 return true; 1204 return failed(interface->decode(*this, result)); 1205 } 1206 1207 LogicalResult 1208 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1209 Identifier dialect, StringRef value, 1210 ShapedType type) { 1211 if (!Dialect::isValidNamespace(dialect.strref())) 1212 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1213 return success(); 1214 } 1215 1216 //===----------------------------------------------------------------------===// 1217 // SparseElementsAttr 1218 //===----------------------------------------------------------------------===// 1219 1220 /// Return the value of the element at the given index. 1221 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1222 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1223 auto type = getType(); 1224 1225 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1226 // as a 1-D index array. 1227 auto sparseIndices = getIndices(); 1228 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1229 1230 // Check to see if the indices are a splat. 1231 if (sparseIndices.isSplat()) { 1232 // If the index is also not a splat of the index value, we know that the 1233 // value is zero. 1234 auto splatIndex = *sparseIndexValues.begin(); 1235 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1236 return getZeroAttr(); 1237 1238 // If the indices are a splat, we also expect the values to be a splat. 
1239 assert(getValues().isSplat() && "expected splat values"); 1240 return getValues().getSplatValue(); 1241 } 1242 1243 // Build a mapping between known indices and the offset of the stored element. 1244 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1245 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1246 size_t rank = type.getRank(); 1247 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1248 mappedIndices.try_emplace( 1249 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1250 1251 // Look for the provided index key within the mapped indices. If the provided 1252 // index is not found, then return a zero attribute. 1253 auto it = mappedIndices.find(index); 1254 if (it == mappedIndices.end()) 1255 return getZeroAttr(); 1256 1257 // Otherwise, return the held sparse value element. 1258 return getValues().getValue(it->second); 1259 } 1260 1261 /// Get a zero APFloat for the given sparse attribute. 1262 APFloat SparseElementsAttr::getZeroAPFloat() const { 1263 auto eltType = getType().getElementType().cast<FloatType>(); 1264 return APFloat(eltType.getFloatSemantics()); 1265 } 1266 1267 /// Get a zero APInt for the given sparse attribute. 1268 APInt SparseElementsAttr::getZeroAPInt() const { 1269 auto eltType = getType().getElementType().cast<IntegerType>(); 1270 return APInt::getNullValue(eltType.getWidth()); 1271 } 1272 1273 /// Get a zero attribute for the given attribute type. 1274 Attribute SparseElementsAttr::getZeroAttr() const { 1275 auto eltType = getType().getElementType(); 1276 1277 // Handle floating point elements. 1278 if (eltType.isa<FloatType>()) 1279 return FloatAttr::get(eltType, 0); 1280 1281 // Otherwise, this is an integer. 1282 // TODO: Handle StringAttr here. 1283 return IntegerAttr::get(eltType, 0); 1284 } 1285 1286 /// Flatten, and return, all of the sparse indices in this attribute in 1287 /// row-major order. 
1288 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1289 std::vector<ptrdiff_t> flatSparseIndices; 1290 1291 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1292 // as a 1-D index array. 1293 auto sparseIndices = getIndices(); 1294 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1295 if (sparseIndices.isSplat()) { 1296 SmallVector<uint64_t, 8> indices(getType().getRank(), 1297 *sparseIndexValues.begin()); 1298 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1299 return flatSparseIndices; 1300 } 1301 1302 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1303 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1304 size_t rank = getType().getRank(); 1305 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1306 flatSparseIndices.push_back(getFlattenedIndex( 1307 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1308 return flatSparseIndices; 1309 } 1310