1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 /// Tablegen Attribute Definitions 26 //===----------------------------------------------------------------------===// 27 28 #define GET_ATTRDEF_CLASSES 29 #include "mlir/IR/BuiltinAttributes.cpp.inc" 30 31 //===----------------------------------------------------------------------===// 32 // DictionaryAttr 33 //===----------------------------------------------------------------------===// 34 35 /// Helper function that does either an in place sort or sorts from source array 36 /// into destination. If inPlace then storage is both the source and the 37 /// destination, else value is the source and storage destination. Returns 38 /// whether source was sorted. 39 template <bool inPlace> 40 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 41 SmallVectorImpl<NamedAttribute> &storage) { 42 // Specialize for the common case. 43 switch (value.size()) { 44 case 0: 45 // Zero already sorted. 46 break; 47 case 1: 48 // One already sorted but may need to be copied. 49 if (!inPlace) 50 storage.assign({value[0]}); 51 break; 52 case 2: { 53 bool isSorted = value[0] < value[1]; 54 if (inPlace) { 55 if (!isSorted) 56 std::swap(storage[0], storage[1]); 57 } else if (isSorted) { 58 storage.assign({value[0], value[1]}); 59 } else { 60 storage.assign({value[1], value[0]}); 61 } 62 return !isSorted; 63 } 64 default: 65 if (!inPlace) 66 storage.assign(value.begin(), value.end()); 67 // Check to see they are sorted already. 68 bool isSorted = llvm::is_sorted(value); 69 if (!isSorted) { 70 // If not, do a general sort. 71 llvm::array_pod_sort(storage.begin(), storage.end()); 72 value = storage; 73 } 74 return !isSorted; 75 } 76 return false; 77 } 78 79 /// Returns an entry with a duplicate name from the given sorted array of named 80 /// attributes. Returns llvm::None if all elements have unique names. 81 static Optional<NamedAttribute> 82 findDuplicateElement(ArrayRef<NamedAttribute> value) { 83 const Optional<NamedAttribute> none{llvm::None}; 84 if (value.size() < 2) 85 return none; 86 87 if (value.size() == 2) 88 return value[0].first == value[1].first ? value[0] : none; 89 90 auto it = std::adjacent_find( 91 value.begin(), value.end(), 92 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 93 return it != value.end() ? *it : none; 94 } 95 96 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 97 SmallVectorImpl<NamedAttribute> &storage) { 98 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 99 assert(!findDuplicateElement(storage) && 100 "DictionaryAttr element names must be unique"); 101 return isSorted; 102 } 103 104 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 105 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 106 assert(!findDuplicateElement(array) && 107 "DictionaryAttr element names must be unique"); 108 return isSorted; 109 } 110 111 Optional<NamedAttribute> 112 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 113 bool isSorted) { 114 if (!isSorted) 115 dictionaryAttrSort</*inPlace=*/true>(array, array); 116 return findDuplicateElement(array); 117 } 118 119 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 120 ArrayRef<NamedAttribute> value) { 121 if (value.empty()) 122 return DictionaryAttr::getEmpty(context); 123 assert(llvm::all_of(value, 124 [](const NamedAttribute &attr) { return attr.second; }) && 125 "value cannot have null entries"); 126 127 // We need to sort the element list to canonicalize it. 128 SmallVector<NamedAttribute, 8> storage; 129 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 130 value = storage; 131 assert(!findDuplicateElement(value) && 132 "DictionaryAttr element names must be unique"); 133 return Base::get(context, value); 134 } 135 /// Construct a dictionary with an array of values that is known to already be 136 /// sorted by name and uniqued. 137 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 138 ArrayRef<NamedAttribute> value) { 139 if (value.empty()) 140 return DictionaryAttr::getEmpty(context); 141 // Ensure that the attribute elements are unique and sorted. 142 assert(llvm::is_sorted(value, 143 [](NamedAttribute l, NamedAttribute r) { 144 return l.first.strref() < r.first.strref(); 145 }) && 146 "expected attribute values to be sorted"); 147 assert(!findDuplicateElement(value) && 148 "DictionaryAttr element names must be unique"); 149 return Base::get(context, value); 150 } 151 152 /// Return the specified attribute if present, null otherwise. 153 Attribute DictionaryAttr::get(StringRef name) const { 154 Optional<NamedAttribute> attr = getNamed(name); 155 return attr ? attr->second : nullptr; 156 } 157 Attribute DictionaryAttr::get(Identifier name) const { 158 Optional<NamedAttribute> attr = getNamed(name); 159 return attr ? attr->second : nullptr; 160 } 161 162 /// Return the specified named attribute if present, None otherwise. 163 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 164 ArrayRef<NamedAttribute> values = getValue(); 165 const auto *it = llvm::lower_bound(values, name); 166 return it != values.end() && it->first == name ? *it 167 : Optional<NamedAttribute>(); 168 } 169 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 170 for (auto elt : getValue()) 171 if (elt.first == name) 172 return elt; 173 return llvm::None; 174 } 175 176 DictionaryAttr::iterator DictionaryAttr::begin() const { 177 return getValue().begin(); 178 } 179 DictionaryAttr::iterator DictionaryAttr::end() const { 180 return getValue().end(); 181 } 182 size_t DictionaryAttr::size() const { return getValue().size(); } 183 184 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 185 return Base::get(context, ArrayRef<NamedAttribute>()); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // FloatAttr 190 //===----------------------------------------------------------------------===// 191 192 FloatAttr FloatAttr::get(Type type, double value) { 193 return Base::get(type.getContext(), type, value); 194 } 195 196 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 197 Type type, double value) { 198 return Base::getChecked(emitError, type.getContext(), type, value); 199 } 200 201 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 202 return Base::get(type.getContext(), type, value); 203 } 204 205 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 206 Type type, const APFloat &value) { 207 return Base::getChecked(emitError, type.getContext(), type, value); 208 } 209 210 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 211 212 double FloatAttr::getValueAsDouble() const { 213 return getValueAsDouble(getValue()); 214 } 215 double FloatAttr::getValueAsDouble(APFloat value) { 216 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 217 bool losesInfo = false; 218 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 219 &losesInfo); 220 } 221 return value.convertToDouble(); 222 } 223 224 /// Verify construction invariants. 225 static LogicalResult 226 verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError, 227 Type type) { 228 if (!type.isa<FloatType>()) 229 return emitError() << "expected floating point type"; 230 return success(); 231 } 232 233 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 234 Type type, double value) { 235 return verifyFloatTypeInvariants(emitError, type); 236 } 237 238 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 239 Type type, const APFloat &value) { 240 // Verify that the type is correct. 241 if (failed(verifyFloatTypeInvariants(emitError, type))) 242 return failure(); 243 244 // Verify that the type semantics match that of the value. 245 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 246 return emitError() 247 << "FloatAttr type doesn't match the type implied by its value"; 248 } 249 return success(); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // SymbolRefAttr 254 //===----------------------------------------------------------------------===// 255 256 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 257 return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 258 } 259 260 StringRef SymbolRefAttr::getLeafReference() const { 261 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 262 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 263 } 264 265 //===----------------------------------------------------------------------===// 266 // IntegerAttr 267 //===----------------------------------------------------------------------===// 268 269 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 270 if (type.isSignlessInteger(1)) 271 return BoolAttr::get(type.getContext(), value.getBoolValue()); 272 return Base::get(type.getContext(), type, value); 273 } 274 275 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 276 // This uses 64 bit APInts by default for index type. 277 if (type.isIndex()) 278 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 279 280 auto intType = type.cast<IntegerType>(); 281 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 282 } 283 284 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 285 286 int64_t IntegerAttr::getInt() const { 287 assert((getImpl()->getType().isIndex() || 288 getImpl()->getType().isSignlessInteger()) && 289 "must be signless integer"); 290 return getValue().getSExtValue(); 291 } 292 293 int64_t IntegerAttr::getSInt() const { 294 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 295 return getValue().getSExtValue(); 296 } 297 298 uint64_t IntegerAttr::getUInt() const { 299 assert(getImpl()->getType().isUnsignedInteger() && 300 "must be unsigned integer"); 301 return getValue().getZExtValue(); 302 } 303 304 static LogicalResult 305 verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError, 306 Type type) { 307 if (type.isa<IntegerType, IndexType>()) 308 return success(); 309 return emitError() << "expected integer or index type"; 310 } 311 312 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 313 Type type, int64_t value) { 314 return verifyIntegerTypeInvariants(emitError, type); 315 } 316 317 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 318 Type type, const APInt &value) { 319 if (failed(verifyIntegerTypeInvariants(emitError, type))) 320 return failure(); 321 if (auto integerType = type.dyn_cast<IntegerType>()) 322 if (integerType.getWidth() != value.getBitWidth()) 323 return emitError() << "integer type bit width (" << integerType.getWidth() 324 << ") doesn't match value bit width (" 325 << value.getBitWidth() << ")"; 326 return success(); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // BoolAttr 331 332 bool BoolAttr::getValue() const { 333 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 334 return storage->getValue().getBoolValue(); 335 } 336 337 bool BoolAttr::classof(Attribute attr) { 338 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 339 return intAttr && intAttr.getType().isSignlessInteger(1); 340 } 341 342 //===----------------------------------------------------------------------===// 343 // OpaqueAttr 344 //===----------------------------------------------------------------------===// 345 346 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 347 Identifier dialect, StringRef attrData, 348 Type type) { 349 if (!Dialect::isValidNamespace(dialect.strref())) 350 return emitError() << "invalid dialect namespace '" << dialect << "'"; 351 return success(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // ElementsAttr 356 //===----------------------------------------------------------------------===// 357 358 ShapedType ElementsAttr::getType() const { 359 return Attribute::getType().cast<ShapedType>(); 360 } 361 362 /// Returns the number of elements held by this attribute. 363 int64_t ElementsAttr::getNumElements() const { 364 return getType().getNumElements(); 365 } 366 367 /// Return the value at the given index. If index does not refer to a valid 368 /// element, then a null attribute is returned. 369 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 370 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 371 return denseAttr.getValue(index); 372 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 373 return opaqueAttr.getValue(index); 374 return cast<SparseElementsAttr>().getValue(index); 375 } 376 377 /// Return if the given 'index' refers to a valid element in this attribute. 378 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 379 auto type = getType(); 380 381 // Verify that the rank of the indices matches the held type. 382 auto rank = type.getRank(); 383 if (rank == 0 && index.size() == 1 && index[0] == 0) 384 return true; 385 if (rank != static_cast<int64_t>(index.size())) 386 return false; 387 388 // Verify that all of the indices are within the shape dimensions. 389 auto shape = type.getShape(); 390 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 391 return static_cast<int64_t>(index[i]) < shape[i]; 392 }); 393 } 394 395 ElementsAttr 396 ElementsAttr::mapValues(Type newElementType, 397 function_ref<APInt(const APInt &)> mapping) const { 398 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 399 return intOrFpAttr.mapValues(newElementType, mapping); 400 llvm_unreachable("unsupported ElementsAttr subtype"); 401 } 402 403 ElementsAttr 404 ElementsAttr::mapValues(Type newElementType, 405 function_ref<APInt(const APFloat &)> mapping) const { 406 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 407 return intOrFpAttr.mapValues(newElementType, mapping); 408 llvm_unreachable("unsupported ElementsAttr subtype"); 409 } 410 411 /// Method for support type inquiry through isa, cast and dyn_cast. 412 bool ElementsAttr::classof(Attribute attr) { 413 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 414 OpaqueElementsAttr, SparseElementsAttr>(); 415 } 416 417 /// Returns the 1 dimensional flattened row-major index from the given 418 /// multi-dimensional index. 419 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 420 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 421 auto type = getType(); 422 423 // Reduce the provided multidimensional index into a flattended 1D row-major 424 // index. 425 auto rank = type.getRank(); 426 auto shape = type.getShape(); 427 uint64_t valueIndex = 0; 428 uint64_t dimMultiplier = 1; 429 for (int i = rank - 1; i >= 0; --i) { 430 valueIndex += index[i] * dimMultiplier; 431 dimMultiplier *= shape[i]; 432 } 433 return valueIndex; 434 } 435 436 //===----------------------------------------------------------------------===// 437 // DenseElementsAttr Utilities 438 //===----------------------------------------------------------------------===// 439 440 /// Get the bitwidth of a dense element type within the buffer. 441 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 442 static size_t getDenseElementStorageWidth(size_t origWidth) { 443 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 444 } 445 static size_t getDenseElementStorageWidth(Type elementType) { 446 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 447 } 448 449 /// Set a bit to a specific value. 450 static void setBit(char *rawData, size_t bitPos, bool value) { 451 if (value) 452 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 453 else 454 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 455 } 456 457 /// Return the value of the specified bit. 458 static bool getBit(const char *rawData, size_t bitPos) { 459 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 460 } 461 462 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 463 /// BE format. 464 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 465 char *result) { 466 assert(llvm::support::endian::system_endianness() == // NOLINT 467 llvm::support::endianness::big); // NOLINT 468 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 469 470 // Copy the words filled with data. 471 // For example, when `value` has 2 words, the first word is filled with data. 472 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 473 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 474 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 475 numFilledWords, result); 476 // Convert last word of APInt to LE format and store it in char 477 // array(`valueLE`). 478 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 479 size_t lastWordPos = numFilledWords; 480 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 481 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 482 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 483 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 484 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 485 // and store it in `result`. 486 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 487 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 488 valueLE.begin(), result + lastWordPos, 489 (numBytes - lastWordPos) * CHAR_BIT, 1); 490 } 491 492 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 493 /// format. 494 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 495 APInt &result) { 496 assert(llvm::support::endian::system_endianness() == // NOLINT 497 llvm::support::endianness::big); // NOLINT 498 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 499 500 // Copy the data that fills the word of `result` from `inArray`. 501 // For example, when `result` has 2 words, the first word will be filled with 502 // data. So, the first 8 bytes are copied from `inArray` here. 503 // `inArray` (10 bytes, BE): |abcdefgh|ij| 504 // ==> `result` (2 words, BE): |abcdefgh|--------| 505 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 506 std::copy_n( 507 inArray, numFilledWords, 508 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 509 510 // Convert array data which will be last word of `result` to LE format, and 511 // store it in char array(`inArrayLE`). 512 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 513 size_t lastWordPos = numFilledWords; 514 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 515 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 516 inArray + lastWordPos, inArrayLE.begin(), 517 (numBytes - lastWordPos) * CHAR_BIT, 1); 518 519 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 520 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 521 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 522 inArrayLE.begin(), 523 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 524 lastWordPos, 525 APInt::APINT_BITS_PER_WORD, 1); 526 } 527 528 /// Writes value to the bit position `bitPos` in array `rawData`. 529 static void writeBits(char *rawData, size_t bitPos, APInt value) { 530 size_t bitWidth = value.getBitWidth(); 531 532 // If the bitwidth is 1 we just toggle the specific bit. 533 if (bitWidth == 1) 534 return setBit(rawData, bitPos, value.isOneValue()); 535 536 // Otherwise, the bit position is guaranteed to be byte aligned. 537 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 538 if (llvm::support::endian::system_endianness() == 539 llvm::support::endianness::big) { 540 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 541 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 542 // work correctly in BE format. 543 // ex. `value` (2 words including 10 bytes) 544 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 545 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 546 rawData + (bitPos / CHAR_BIT)); 547 } else { 548 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 549 llvm::divideCeil(bitWidth, CHAR_BIT), 550 rawData + (bitPos / CHAR_BIT)); 551 } 552 } 553 554 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 555 /// `rawData`. 556 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 557 // Handle a boolean bit position. 558 if (bitWidth == 1) 559 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 560 561 // Otherwise, the bit position must be 8-bit aligned. 562 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 563 APInt result(bitWidth, 0); 564 if (llvm::support::endian::system_endianness() == 565 llvm::support::endianness::big) { 566 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 567 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 568 // work correctly in BE format. 569 // ex. `result` (2 words including 10 bytes) 570 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 571 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 572 llvm::divideCeil(bitWidth, CHAR_BIT), result); 573 } else { 574 std::copy_n(rawData + (bitPos / CHAR_BIT), 575 llvm::divideCeil(bitWidth, CHAR_BIT), 576 const_cast<char *>( 577 reinterpret_cast<const char *>(result.getRawData()))); 578 } 579 return result; 580 } 581 582 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 583 /// the same element count as 'type'. 584 template <typename Values> 585 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 586 return (values.size() == 1) || 587 (type.getNumElements() == static_cast<int64_t>(values.size())); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // DenseElementsAttr Iterators 592 //===----------------------------------------------------------------------===// 593 594 //===----------------------------------------------------------------------===// 595 // AttributeElementIterator 596 597 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 598 DenseElementsAttr attr, size_t index) 599 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 600 Attribute, Attribute, Attribute>( 601 attr.getAsOpaquePointer(), index) {} 602 603 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 604 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 605 Type eltTy = owner.getType().getElementType(); 606 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 607 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 608 if (eltTy.isa<IndexType>()) 609 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 610 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 611 IntElementIterator intIt(owner, index); 612 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 613 return FloatAttr::get(eltTy, *floatIt); 614 } 615 if (owner.isa<DenseStringElementsAttr>()) { 616 ArrayRef<StringRef> vals = owner.getRawStringData(); 617 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 618 } 619 llvm_unreachable("unexpected element type"); 620 } 621 622 //===----------------------------------------------------------------------===// 623 // BoolElementIterator 624 625 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 626 DenseElementsAttr attr, size_t dataIndex) 627 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 628 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 629 630 bool DenseElementsAttr::BoolElementIterator::operator*() const { 631 return getBit(getData(), getDataIndex()); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // IntElementIterator 636 637 DenseElementsAttr::IntElementIterator::IntElementIterator( 638 DenseElementsAttr attr, size_t dataIndex) 639 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 640 attr.getRawData().data(), attr.isSplat(), dataIndex), 641 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 642 643 APInt DenseElementsAttr::IntElementIterator::operator*() const { 644 return readBits(getData(), 645 getDataIndex() * getDenseElementStorageWidth(bitWidth), 646 bitWidth); 647 } 648 649 //===----------------------------------------------------------------------===// 650 // ComplexIntElementIterator 651 652 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 653 DenseElementsAttr attr, size_t dataIndex) 654 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 655 std::complex<APInt>, std::complex<APInt>, 656 std::complex<APInt>>( 657 attr.getRawData().data(), attr.isSplat(), dataIndex) { 658 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 659 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 660 } 661 662 std::complex<APInt> 663 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 664 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 665 size_t offset = getDataIndex() * storageWidth * 2; 666 return {readBits(getData(), offset, bitWidth), 667 readBits(getData(), offset + storageWidth, bitWidth)}; 668 } 669 670 //===----------------------------------------------------------------------===// 671 // FloatElementIterator 672 673 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 674 const llvm::fltSemantics &smt, IntElementIterator it) 675 : llvm::mapped_iterator<IntElementIterator, 676 std::function<APFloat(const APInt &)>>( 677 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 678 679 //===----------------------------------------------------------------------===// 680 // ComplexFloatElementIterator 681 682 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 683 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 684 : llvm::mapped_iterator< 685 ComplexIntElementIterator, 686 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 687 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 688 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 689 }) {} 690 691 //===----------------------------------------------------------------------===// 692 // DenseElementsAttr 693 //===----------------------------------------------------------------------===// 694 695 /// Method for support type inquiry through isa, cast and dyn_cast. 696 bool DenseElementsAttr::classof(Attribute attr) { 697 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 698 } 699 700 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 701 ArrayRef<Attribute> values) { 702 assert(hasSameElementsOrSplat(type, values)); 703 704 // If the element type is not based on int/float/index, assume it is a string 705 // type. 706 auto eltType = type.getElementType(); 707 if (!type.getElementType().isIntOrIndexOrFloat()) { 708 SmallVector<StringRef, 8> stringValues; 709 stringValues.reserve(values.size()); 710 for (Attribute attr : values) { 711 assert(attr.isa<StringAttr>() && 712 "expected string value for non integer/index/float element"); 713 stringValues.push_back(attr.cast<StringAttr>().getValue()); 714 } 715 return get(type, stringValues); 716 } 717 718 // Otherwise, get the raw storage width to use for the allocation. 719 size_t bitWidth = getDenseElementBitWidth(eltType); 720 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 721 722 // Compress the attribute values into a character buffer. 723 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 724 values.size()); 725 APInt intVal; 726 for (unsigned i = 0, e = values.size(); i < e; ++i) { 727 assert(eltType == values[i].getType() && 728 "expected attribute value to have element type"); 729 if (eltType.isa<FloatType>()) 730 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 731 else if (eltType.isa<IntegerType>()) 732 intVal = values[i].cast<IntegerAttr>().getValue(); 733 else 734 llvm_unreachable("unexpected element type"); 735 736 assert(intVal.getBitWidth() == bitWidth && 737 "expected value to have same bitwidth as element type"); 738 writeBits(data.data(), i * storageBitWidth, intVal); 739 } 740 return DenseIntOrFPElementsAttr::getRaw(type, data, 741 /*isSplat=*/(values.size() == 1)); 742 } 743 744 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 745 ArrayRef<bool> values) { 746 assert(hasSameElementsOrSplat(type, values)); 747 assert(type.getElementType().isInteger(1)); 748 749 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 750 for (int i = 0, e = values.size(); i != e; ++i) 751 setBit(buff.data(), i, values[i]); 752 return DenseIntOrFPElementsAttr::getRaw(type, buff, 753 /*isSplat=*/(values.size() == 1)); 754 } 755 756 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 757 ArrayRef<StringRef> values) { 758 assert(!type.getElementType().isIntOrFloat()); 759 return DenseStringElementsAttr::get(type, values); 760 } 761 762 /// Constructs a dense integer elements attribute from an array of APInt 763 /// values. Each APInt value is expected to have the same bitwidth as the 764 /// element type of 'type'. 765 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 766 ArrayRef<APInt> values) { 767 assert(type.getElementType().isIntOrIndex()); 768 assert(hasSameElementsOrSplat(type, values)); 769 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 770 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 771 /*isSplat=*/(values.size() == 1)); 772 } 773 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 774 ArrayRef<std::complex<APInt>> values) { 775 ComplexType complex = type.getElementType().cast<ComplexType>(); 776 assert(complex.getElementType().isa<IntegerType>()); 777 assert(hasSameElementsOrSplat(type, values)); 778 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 779 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 780 values.size() * 2); 781 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 782 /*isSplat=*/(values.size() == 1)); 783 } 784 785 // Constructs a dense float elements attribute from an array of APFloat 786 // values. Each APFloat value is expected to have the same bitwidth as the 787 // element type of 'type'. 788 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 789 ArrayRef<APFloat> values) { 790 assert(type.getElementType().isa<FloatType>()); 791 assert(hasSameElementsOrSplat(type, values)); 792 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 793 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 794 /*isSplat=*/(values.size() == 1)); 795 } 796 DenseElementsAttr 797 DenseElementsAttr::get(ShapedType type, 798 ArrayRef<std::complex<APFloat>> values) { 799 ComplexType complex = type.getElementType().cast<ComplexType>(); 800 assert(complex.getElementType().isa<FloatType>()); 801 assert(hasSameElementsOrSplat(type, values)); 802 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 803 values.size() * 2); 804 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 805 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 806 /*isSplat=*/(values.size() == 1)); 807 } 808 809 /// Construct a dense elements attribute from a raw buffer representing the 810 /// data for this attribute. Users should generally not use this methods as 811 /// the expected buffer format may not be a form the user expects. 812 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 813 ArrayRef<char> rawBuffer, 814 bool isSplatBuffer) { 815 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 816 } 817 818 /// Returns true if the given buffer is a valid raw buffer for the given type. 819 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 820 ArrayRef<char> rawBuffer, 821 bool &detectedSplat) { 822 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 823 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 824 825 // Storage width of 1 is special as it is packed by the bit. 826 if (storageWidth == 1) { 827 // Check for a splat, or a buffer equal to the number of elements. 828 if ((detectedSplat = rawBuffer.size() == 1)) 829 return true; 830 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 831 } 832 // All other types are 8-bit aligned. 833 if ((detectedSplat = rawBufferWidth == storageWidth)) 834 return true; 835 return rawBufferWidth == (storageWidth * type.getNumElements()); 836 } 837 838 /// Check the information for a C++ data type, check if this type is valid for 839 /// the current attribute. This method is used to verify specific type 840 /// invariants that the templatized 'getValues' method cannot. 841 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 842 bool isSigned) { 843 // Make sure that the data element size is the same as the type element width. 844 if (getDenseElementBitWidth(type) != 845 static_cast<size_t>(dataEltSize * CHAR_BIT)) 846 return false; 847 848 // Check that the element type is either float or integer or index. 849 if (!isInt) 850 return type.isa<FloatType>(); 851 if (type.isIndex()) 852 return true; 853 854 auto intType = type.dyn_cast<IntegerType>(); 855 if (!intType) 856 return false; 857 858 // Make sure signedness semantics is consistent. 859 if (intType.isSignless()) 860 return true; 861 return intType.isSigned() ? isSigned : !isSigned; 862 } 863 864 /// Defaults down the subclass implementation. 865 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 866 ArrayRef<char> data, 867 int64_t dataEltSize, 868 bool isInt, bool isSigned) { 869 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 870 isSigned); 871 } 872 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 873 ArrayRef<char> data, 874 int64_t dataEltSize, 875 bool isInt, 876 bool isSigned) { 877 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 878 isInt, isSigned); 879 } 880 881 /// A method used to verify specific type invariants that the templatized 'get' 882 /// method cannot. 883 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 884 bool isSigned) const { 885 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 886 isSigned); 887 } 888 889 /// Check the information for a C++ data type, check if this type is valid for 890 /// the current attribute. 891 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 892 bool isSigned) const { 893 return ::isValidIntOrFloat( 894 getType().getElementType().cast<ComplexType>().getElementType(), 895 dataEltSize / 2, isInt, isSigned); 896 } 897 898 /// Returns true if this attribute corresponds to a splat, i.e. if all element 899 /// values are the same. 900 bool DenseElementsAttr::isSplat() const { 901 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 902 } 903 904 /// Return the held element values as a range of Attributes. 905 auto DenseElementsAttr::getAttributeValues() const 906 -> llvm::iterator_range<AttributeElementIterator> { 907 return {attr_value_begin(), attr_value_end()}; 908 } 909 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 910 return AttributeElementIterator(*this, 0); 911 } 912 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 913 return AttributeElementIterator(*this, getNumElements()); 914 } 915 916 /// Return the held element values as a range of bool. The element type of 917 /// this attribute must be of integer type of bitwidth 1. 918 auto DenseElementsAttr::getBoolValues() const 919 -> llvm::iterator_range<BoolElementIterator> { 920 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 921 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 922 (void)eltType; 923 return {BoolElementIterator(*this, 0), 924 BoolElementIterator(*this, getNumElements())}; 925 } 926 927 /// Return the held element values as a range of APInts. The element type of 928 /// this attribute must be of integer type. 929 auto DenseElementsAttr::getIntValues() const 930 -> llvm::iterator_range<IntElementIterator> { 931 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 932 return {raw_int_begin(), raw_int_end()}; 933 } 934 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 935 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 936 return raw_int_begin(); 937 } 938 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 939 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 940 return raw_int_end(); 941 } 942 auto DenseElementsAttr::getComplexIntValues() const 943 -> llvm::iterator_range<ComplexIntElementIterator> { 944 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 945 (void)eltTy; 946 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 947 return {ComplexIntElementIterator(*this, 0), 948 ComplexIntElementIterator(*this, getNumElements())}; 949 } 950 951 /// Return the held element values as a range of APFloat. The element type of 952 /// this attribute must be of float type. 953 auto DenseElementsAttr::getFloatValues() const 954 -> llvm::iterator_range<FloatElementIterator> { 955 auto elementType = getType().getElementType().cast<FloatType>(); 956 const auto &elementSemantics = elementType.getFloatSemantics(); 957 return {FloatElementIterator(elementSemantics, raw_int_begin()), 958 FloatElementIterator(elementSemantics, raw_int_end())}; 959 } 960 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 961 return getFloatValues().begin(); 962 } 963 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 964 return getFloatValues().end(); 965 } 966 auto DenseElementsAttr::getComplexFloatValues() const 967 -> llvm::iterator_range<ComplexFloatElementIterator> { 968 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 969 assert(eltTy.isa<FloatType>() && "expected complex float type"); 970 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 971 return {{semantics, {*this, 0}}, 972 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 973 } 974 975 /// Return the raw storage data held by this attribute. 976 ArrayRef<char> DenseElementsAttr::getRawData() const { 977 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 978 } 979 980 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 981 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 982 } 983 984 /// Return a new DenseElementsAttr that has the same data as the current 985 /// attribute, but has been reshaped to 'newType'. The new type must have the 986 /// same total number of elements as well as element type. 987 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 988 ShapedType curType = getType(); 989 if (curType == newType) 990 return *this; 991 992 (void)curType; 993 assert(newType.getElementType() == curType.getElementType() && 994 "expected the same element type"); 995 assert(newType.getNumElements() == curType.getNumElements() && 996 "expected the same number of elements"); 997 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 998 } 999 1000 DenseElementsAttr 1001 DenseElementsAttr::mapValues(Type newElementType, 1002 function_ref<APInt(const APInt &)> mapping) const { 1003 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1004 } 1005 1006 DenseElementsAttr DenseElementsAttr::mapValues( 1007 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1008 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1009 } 1010 1011 //===----------------------------------------------------------------------===// 1012 // DenseStringElementsAttr 1013 //===----------------------------------------------------------------------===// 1014 1015 DenseStringElementsAttr 1016 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1017 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1018 } 1019 1020 //===----------------------------------------------------------------------===// 1021 // DenseIntOrFPElementsAttr 1022 //===----------------------------------------------------------------------===// 1023 1024 /// Utility method to write a range of APInt values to a buffer. 1025 template <typename APRangeT> 1026 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1027 APRangeT &&values) { 1028 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1029 size_t offset = 0; 1030 for (auto it = values.begin(), e = values.end(); it != e; 1031 ++it, offset += storageWidth) { 1032 assert((*it).getBitWidth() <= storageWidth); 1033 writeBits(data.data(), offset, *it); 1034 } 1035 } 1036 1037 /// Constructs a dense elements attribute from an array of raw APFloat values. 1038 /// Each APFloat value is expected to have the same bitwidth as the element 1039 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1040 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1041 size_t storageWidth, 1042 ArrayRef<APFloat> values, 1043 bool isSplat) { 1044 std::vector<char> data; 1045 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1046 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1047 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1048 } 1049 1050 /// Constructs a dense elements attribute from an array of raw APInt values. 1051 /// Each APInt value is expected to have the same bitwidth as the element type 1052 /// of 'type'. 1053 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1054 size_t storageWidth, 1055 ArrayRef<APInt> values, 1056 bool isSplat) { 1057 std::vector<char> data; 1058 writeAPIntsToBuffer(storageWidth, data, values); 1059 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1060 } 1061 1062 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1063 ArrayRef<char> data, 1064 bool isSplat) { 1065 assert((type.isa<RankedTensorType, VectorType>()) && 1066 "type must be ranked tensor or vector"); 1067 assert(type.hasStaticShape() && "type must have static shape"); 1068 return Base::get(type.getContext(), type, data, isSplat); 1069 } 1070 1071 /// Overload of the raw 'get' method that asserts that the given type is of 1072 /// complex type. This method is used to verify type invariants that the 1073 /// templatized 'get' method cannot. 1074 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1075 ArrayRef<char> data, 1076 int64_t dataEltSize, 1077 bool isInt, 1078 bool isSigned) { 1079 assert(::isValidIntOrFloat( 1080 type.getElementType().cast<ComplexType>().getElementType(), 1081 dataEltSize / 2, isInt, isSigned)); 1082 1083 int64_t numElements = data.size() / dataEltSize; 1084 assert(numElements == 1 || numElements == type.getNumElements()); 1085 return getRaw(type, data, /*isSplat=*/numElements == 1); 1086 } 1087 1088 /// Overload of the 'getRaw' method that asserts that the given type is of 1089 /// integer type. This method is used to verify type invariants that the 1090 /// templatized 'get' method cannot. 1091 DenseElementsAttr 1092 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1093 int64_t dataEltSize, bool isInt, 1094 bool isSigned) { 1095 assert( 1096 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1097 1098 int64_t numElements = data.size() / dataEltSize; 1099 assert(numElements == 1 || numElements == type.getNumElements()); 1100 return getRaw(type, data, /*isSplat=*/numElements == 1); 1101 } 1102 1103 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1104 const char *inRawData, char *outRawData, size_t elementBitWidth, 1105 size_t numElements) { 1106 using llvm::support::ulittle16_t; 1107 using llvm::support::ulittle32_t; 1108 using llvm::support::ulittle64_t; 1109 1110 assert(llvm::support::endian::system_endianness() == // NOLINT 1111 llvm::support::endianness::big); // NOLINT 1112 // NOLINT to avoid warning message about replacing by static_assert() 1113 1114 // Following std::copy_n always converts endianness on BE machine. 1115 switch (elementBitWidth) { 1116 case 16: { 1117 const ulittle16_t *inRawDataPos = 1118 reinterpret_cast<const ulittle16_t *>(inRawData); 1119 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1120 std::copy_n(inRawDataPos, numElements, outDataPos); 1121 break; 1122 } 1123 case 32: { 1124 const ulittle32_t *inRawDataPos = 1125 reinterpret_cast<const ulittle32_t *>(inRawData); 1126 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1127 std::copy_n(inRawDataPos, numElements, outDataPos); 1128 break; 1129 } 1130 case 64: { 1131 const ulittle64_t *inRawDataPos = 1132 reinterpret_cast<const ulittle64_t *>(inRawData); 1133 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1134 std::copy_n(inRawDataPos, numElements, outDataPos); 1135 break; 1136 } 1137 default: { 1138 size_t nBytes = elementBitWidth / CHAR_BIT; 1139 for (size_t i = 0; i < nBytes; i++) 1140 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1141 break; 1142 } 1143 } 1144 } 1145 1146 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1147 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1148 ShapedType type) { 1149 size_t numElements = type.getNumElements(); 1150 Type elementType = type.getElementType(); 1151 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1152 elementType = complexTy.getElementType(); 1153 numElements = numElements * 2; 1154 } 1155 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1156 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1157 inRawData.size() <= outRawData.size()); 1158 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1159 elementBitWidth, numElements); 1160 } 1161 1162 //===----------------------------------------------------------------------===// 1163 // DenseFPElementsAttr 1164 //===----------------------------------------------------------------------===// 1165 1166 template <typename Fn, typename Attr> 1167 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1168 Type newElementType, 1169 llvm::SmallVectorImpl<char> &data) { 1170 size_t bitWidth = getDenseElementBitWidth(newElementType); 1171 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1172 1173 ShapedType newArrayType; 1174 if (inType.isa<RankedTensorType>()) 1175 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1176 else if (inType.isa<UnrankedTensorType>()) 1177 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1178 else if (inType.isa<VectorType>()) 1179 newArrayType = VectorType::get(inType.getShape(), newElementType); 1180 else 1181 assert(newArrayType && "Unhandled tensor type"); 1182 1183 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1184 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1185 1186 // Functor used to process a single element value of the attribute. 1187 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1188 auto newInt = mapping(value); 1189 assert(newInt.getBitWidth() == bitWidth); 1190 writeBits(data.data(), index * storageBitWidth, newInt); 1191 }; 1192 1193 // Check for the splat case. 1194 if (attr.isSplat()) { 1195 processElt(*attr.begin(), /*index=*/0); 1196 return newArrayType; 1197 } 1198 1199 // Otherwise, process all of the element values. 1200 uint64_t elementIdx = 0; 1201 for (auto value : attr) 1202 processElt(value, elementIdx++); 1203 return newArrayType; 1204 } 1205 1206 DenseElementsAttr DenseFPElementsAttr::mapValues( 1207 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1208 llvm::SmallVector<char, 8> elementData; 1209 auto newArrayType = 1210 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1211 1212 return getRaw(newArrayType, elementData, isSplat()); 1213 } 1214 1215 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1216 bool DenseFPElementsAttr::classof(Attribute attr) { 1217 return attr.isa<DenseElementsAttr>() && 1218 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1219 } 1220 1221 //===----------------------------------------------------------------------===// 1222 // DenseIntElementsAttr 1223 //===----------------------------------------------------------------------===// 1224 1225 DenseElementsAttr DenseIntElementsAttr::mapValues( 1226 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1227 llvm::SmallVector<char, 8> elementData; 1228 auto newArrayType = 1229 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1230 1231 return getRaw(newArrayType, elementData, isSplat()); 1232 } 1233 1234 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1235 bool DenseIntElementsAttr::classof(Attribute attr) { 1236 return attr.isa<DenseElementsAttr>() && 1237 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1238 } 1239 1240 //===----------------------------------------------------------------------===// 1241 // OpaqueElementsAttr 1242 //===----------------------------------------------------------------------===// 1243 1244 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1245 StringRef bytes) { 1246 assert(TensorType::isValidElementType(type.getElementType()) && 1247 "Input element type should be a valid tensor element type"); 1248 return Base::get(type.getContext(), type, dialect, bytes); 1249 } 1250 1251 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1252 1253 /// Return the value at the given index. If index does not refer to a valid 1254 /// element, then a null attribute is returned. 1255 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1256 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1257 return Attribute(); 1258 } 1259 1260 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1261 1262 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1263 auto *d = getDialect(); 1264 if (!d) 1265 return true; 1266 auto *interface = 1267 d->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1268 if (!interface) 1269 return true; 1270 return failed(interface->decode(*this, result)); 1271 } 1272 1273 //===----------------------------------------------------------------------===// 1274 // SparseElementsAttr 1275 //===----------------------------------------------------------------------===// 1276 1277 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1278 DenseElementsAttr indices, 1279 DenseElementsAttr values) { 1280 assert(indices.getType().getElementType().isInteger(64) && 1281 "expected sparse indices to be 64-bit integer values"); 1282 assert((type.isa<RankedTensorType, VectorType>()) && 1283 "type must be ranked tensor or vector"); 1284 assert(type.hasStaticShape() && "type must have static shape"); 1285 return Base::get(type.getContext(), type, 1286 indices.cast<DenseIntElementsAttr>(), values); 1287 } 1288 1289 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1290 return getImpl()->indices; 1291 } 1292 1293 DenseElementsAttr SparseElementsAttr::getValues() const { 1294 return getImpl()->values; 1295 } 1296 1297 /// Return the value of the element at the given index. 1298 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1299 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1300 auto type = getType(); 1301 1302 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1303 // as a 1-D index array. 1304 auto sparseIndices = getIndices(); 1305 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1306 1307 // Check to see if the indices are a splat. 1308 if (sparseIndices.isSplat()) { 1309 // If the index is also not a splat of the index value, we know that the 1310 // value is zero. 1311 auto splatIndex = *sparseIndexValues.begin(); 1312 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1313 return getZeroAttr(); 1314 1315 // If the indices are a splat, we also expect the values to be a splat. 1316 assert(getValues().isSplat() && "expected splat values"); 1317 return getValues().getSplatValue(); 1318 } 1319 1320 // Build a mapping between known indices and the offset of the stored element. 1321 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1322 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1323 size_t rank = type.getRank(); 1324 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1325 mappedIndices.try_emplace( 1326 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1327 1328 // Look for the provided index key within the mapped indices. If the provided 1329 // index is not found, then return a zero attribute. 1330 auto it = mappedIndices.find(index); 1331 if (it == mappedIndices.end()) 1332 return getZeroAttr(); 1333 1334 // Otherwise, return the held sparse value element. 1335 return getValues().getValue(it->second); 1336 } 1337 1338 /// Get a zero APFloat for the given sparse attribute. 1339 APFloat SparseElementsAttr::getZeroAPFloat() const { 1340 auto eltType = getType().getElementType().cast<FloatType>(); 1341 return APFloat(eltType.getFloatSemantics()); 1342 } 1343 1344 /// Get a zero APInt for the given sparse attribute. 1345 APInt SparseElementsAttr::getZeroAPInt() const { 1346 auto eltType = getType().getElementType().cast<IntegerType>(); 1347 return APInt::getNullValue(eltType.getWidth()); 1348 } 1349 1350 /// Get a zero attribute for the given attribute type. 1351 Attribute SparseElementsAttr::getZeroAttr() const { 1352 auto eltType = getType().getElementType(); 1353 1354 // Handle floating point elements. 1355 if (eltType.isa<FloatType>()) 1356 return FloatAttr::get(eltType, 0); 1357 1358 // Otherwise, this is an integer. 1359 // TODO: Handle StringAttr here. 1360 return IntegerAttr::get(eltType, 0); 1361 } 1362 1363 /// Flatten, and return, all of the sparse indices in this attribute in 1364 /// row-major order. 1365 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1366 std::vector<ptrdiff_t> flatSparseIndices; 1367 1368 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1369 // as a 1-D index array. 1370 auto sparseIndices = getIndices(); 1371 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1372 if (sparseIndices.isSplat()) { 1373 SmallVector<uint64_t, 8> indices(getType().getRank(), 1374 *sparseIndexValues.begin()); 1375 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1376 return flatSparseIndices; 1377 } 1378 1379 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1380 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1381 size_t rank = getType().getRank(); 1382 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1383 flatSparseIndices.push_back(getFlattenedIndex( 1384 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1385 return flatSparseIndices; 1386 } 1387