1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 18 #include "llvm/ADT/APSInt.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/Twine.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> walkAttrsFn, 
51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // DictionaryAttr 58 //===----------------------------------------------------------------------===// 59 60 /// Helper function that does either an in place sort or sorts from source array 61 /// into destination. If inPlace then storage is both the source and the 62 /// destination, else value is the source and storage destination. Returns 63 /// whether source was sorted. 64 template <bool inPlace> 65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 66 SmallVectorImpl<NamedAttribute> &storage) { 67 // Specialize for the common case. 68 switch (value.size()) { 69 case 0: 70 // Zero already sorted. 71 break; 72 case 1: 73 // One already sorted but may need to be copied. 74 if (!inPlace) 75 storage.assign({value[0]}); 76 break; 77 case 2: { 78 bool isSorted = value[0] < value[1]; 79 if (inPlace) { 80 if (!isSorted) 81 std::swap(storage[0], storage[1]); 82 } else if (isSorted) { 83 storage.assign({value[0], value[1]}); 84 } else { 85 storage.assign({value[1], value[0]}); 86 } 87 return !isSorted; 88 } 89 default: 90 if (!inPlace) 91 storage.assign(value.begin(), value.end()); 92 // Check to see they are sorted already. 93 bool isSorted = llvm::is_sorted(value); 94 // If not, do a general sort. 95 if (!isSorted) 96 llvm::array_pod_sort(storage.begin(), storage.end()); 97 return !isSorted; 98 } 99 return false; 100 } 101 102 /// Returns an entry with a duplicate name from the given sorted array of named 103 /// attributes. Returns llvm::None if all elements have unique names. 104 static Optional<NamedAttribute> 105 findDuplicateElement(ArrayRef<NamedAttribute> value) { 106 const Optional<NamedAttribute> none{llvm::None}; 107 if (value.size() < 2) 108 return none; 109 110 if (value.size() == 2) 111 return value[0].first == value[1].first ? 
value[0] : none; 112 113 auto it = std::adjacent_find( 114 value.begin(), value.end(), 115 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 116 return it != value.end() ? *it : none; 117 } 118 119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 120 SmallVectorImpl<NamedAttribute> &storage) { 121 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 122 assert(!findDuplicateElement(storage) && 123 "DictionaryAttr element names must be unique"); 124 return isSorted; 125 } 126 127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 128 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 129 assert(!findDuplicateElement(array) && 130 "DictionaryAttr element names must be unique"); 131 return isSorted; 132 } 133 134 Optional<NamedAttribute> 135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 136 bool isSorted) { 137 if (!isSorted) 138 dictionaryAttrSort</*inPlace=*/true>(array, array); 139 return findDuplicateElement(array); 140 } 141 142 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 143 ArrayRef<NamedAttribute> value) { 144 if (value.empty()) 145 return DictionaryAttr::getEmpty(context); 146 assert(llvm::all_of(value, 147 [](const NamedAttribute &attr) { return attr.second; }) && 148 "value cannot have null entries"); 149 150 // We need to sort the element list to canonicalize it. 151 SmallVector<NamedAttribute, 8> storage; 152 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 153 value = storage; 154 assert(!findDuplicateElement(value) && 155 "DictionaryAttr element names must be unique"); 156 return Base::get(context, value); 157 } 158 /// Construct a dictionary with an array of values that is known to already be 159 /// sorted by name and uniqued. 
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 161 ArrayRef<NamedAttribute> value) { 162 if (value.empty()) 163 return DictionaryAttr::getEmpty(context); 164 // Ensure that the attribute elements are unique and sorted. 165 assert(llvm::is_sorted(value, 166 [](NamedAttribute l, NamedAttribute r) { 167 return l.first.strref() < r.first.strref(); 168 }) && 169 "expected attribute values to be sorted"); 170 assert(!findDuplicateElement(value) && 171 "DictionaryAttr element names must be unique"); 172 return Base::get(context, value); 173 } 174 175 /// Return the specified attribute if present, null otherwise. 176 Attribute DictionaryAttr::get(StringRef name) const { 177 Optional<NamedAttribute> attr = getNamed(name); 178 return attr ? attr->second : nullptr; 179 } 180 Attribute DictionaryAttr::get(Identifier name) const { 181 Optional<NamedAttribute> attr = getNamed(name); 182 return attr ? attr->second : nullptr; 183 } 184 185 /// Return the specified named attribute if present, None otherwise. 186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 187 ArrayRef<NamedAttribute> values = getValue(); 188 const auto *it = llvm::lower_bound(values, name); 189 return it != values.end() && it->first == name ? 
*it 190 : Optional<NamedAttribute>(); 191 } 192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 193 for (auto elt : getValue()) 194 if (elt.first == name) 195 return elt; 196 return llvm::None; 197 } 198 199 DictionaryAttr::iterator DictionaryAttr::begin() const { 200 return getValue().begin(); 201 } 202 DictionaryAttr::iterator DictionaryAttr::end() const { 203 return getValue().end(); 204 } 205 size_t DictionaryAttr::size() const { return getValue().size(); } 206 207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 208 return Base::get(context, ArrayRef<NamedAttribute>()); 209 } 210 211 void DictionaryAttr::walkImmediateSubElements( 212 function_ref<void(Attribute)> walkAttrsFn, 213 function_ref<void(Type)> walkTypesFn) const { 214 for (Attribute attr : llvm::make_second_range(getValue())) 215 walkAttrsFn(attr); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // StringAttr 220 //===----------------------------------------------------------------------===// 221 222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 223 return Base::get(context, "", NoneType::get(context)); 224 } 225 226 /// Twine support for StringAttr. 227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 228 // Fast-path empty twine. 229 if (twine.isTriviallyEmpty()) 230 return get(context); 231 SmallVector<char, 32> tempStr; 232 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 233 } 234 235 /// Twine support for StringAttr. 
236 StringAttr StringAttr::get(const Twine &twine, Type type) { 237 SmallVector<char, 32> tempStr; 238 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // FloatAttr 243 //===----------------------------------------------------------------------===// 244 245 double FloatAttr::getValueAsDouble() const { 246 return getValueAsDouble(getValue()); 247 } 248 double FloatAttr::getValueAsDouble(APFloat value) { 249 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 250 bool losesInfo = false; 251 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 252 &losesInfo); 253 } 254 return value.convertToDouble(); 255 } 256 257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 258 Type type, APFloat value) { 259 // Verify that the type is correct. 260 if (!type.isa<FloatType>()) 261 return emitError() << "expected floating point type"; 262 263 // Verify that the type semantics match that of the value. 264 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 265 return emitError() 266 << "FloatAttr type doesn't match the type implied by its value"; 267 } 268 return success(); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // SymbolRefAttr 273 //===----------------------------------------------------------------------===// 274 275 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 276 return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 277 } 278 279 StringRef SymbolRefAttr::getLeafReference() const { 280 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 281 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getValue(); 282 } 283 284 //===----------------------------------------------------------------------===// 285 // IntegerAttr 286 //===----------------------------------------------------------------------===// 287 288 int64_t IntegerAttr::getInt() const { 289 assert((getType().isIndex() || getType().isSignlessInteger()) && 290 "must be signless integer"); 291 return getValue().getSExtValue(); 292 } 293 294 int64_t IntegerAttr::getSInt() const { 295 assert(getType().isSignedInteger() && "must be signed integer"); 296 return getValue().getSExtValue(); 297 } 298 299 uint64_t IntegerAttr::getUInt() const { 300 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 301 return getValue().getZExtValue(); 302 } 303 304 /// Return the value as an APSInt which carries the signed from the type of 305 /// the attribute. This traps on signless integers types! 306 APSInt IntegerAttr::getAPSInt() const { 307 assert(!getType().isSignlessInteger() && 308 "Signless integers don't carry a sign for APSInt"); 309 return APSInt(getValue(), getType().isUnsignedInteger()); 310 } 311 312 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 313 Type type, APInt value) { 314 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 315 if (integerType.getWidth() != value.getBitWidth()) 316 return emitError() << "integer type bit width (" << integerType.getWidth() 317 << ") doesn't match value bit width (" 318 << value.getBitWidth() << ")"; 319 return success(); 320 } 321 if (type.isa<IndexType>()) 322 return success(); 323 return emitError() << "expected integer or index type"; 324 } 325 326 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 327 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 328 return attr.cast<BoolAttr>(); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // BoolAttr 333 334 bool 
BoolAttr::getValue() const { 335 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 336 return storage->value.getBoolValue(); 337 } 338 339 bool BoolAttr::classof(Attribute attr) { 340 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 341 return intAttr && intAttr.getType().isSignlessInteger(1); 342 } 343 344 //===----------------------------------------------------------------------===// 345 // OpaqueAttr 346 //===----------------------------------------------------------------------===// 347 348 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 349 Identifier dialect, StringRef attrData, 350 Type type) { 351 if (!Dialect::isValidNamespace(dialect.strref())) 352 return emitError() << "invalid dialect namespace '" << dialect << "'"; 353 354 // Check that the dialect is actually registered. 355 MLIRContext *context = dialect.getContext(); 356 if (!context->allowsUnregisteredDialects() && 357 !context->getLoadedDialect(dialect.strref())) { 358 return emitError() 359 << "#" << dialect << "<\"" << attrData << "\"> : " << type 360 << " attribute created with unregistered dialect. If this is " 361 "intended, please call allowUnregisteredDialects() on the " 362 "MLIRContext, or use -allow-unregistered-dialect with " 363 "the MLIR opt tool used"; 364 } 365 366 return success(); 367 } 368 369 //===----------------------------------------------------------------------===// 370 // ElementsAttr 371 //===----------------------------------------------------------------------===// 372 373 ShapedType ElementsAttr::getType() const { 374 return Attribute::getType().cast<ShapedType>(); 375 } 376 377 /// Returns the number of elements held by this attribute. 378 int64_t ElementsAttr::getNumElements() const { 379 return getType().getNumElements(); 380 } 381 382 /// Return the value at the given index. If index does not refer to a valid 383 /// element, then a null attribute is returned. 
384 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 385 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 386 return denseAttr.getValue(index); 387 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 388 return opaqueAttr.getValue(index); 389 return cast<SparseElementsAttr>().getValue(index); 390 } 391 392 /// Return if the given 'index' refers to a valid element in this attribute. 393 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 394 auto type = getType(); 395 396 // Verify that the rank of the indices matches the held type. 397 auto rank = type.getRank(); 398 if (rank == 0 && index.size() == 1 && index[0] == 0) 399 return true; 400 if (rank != static_cast<int64_t>(index.size())) 401 return false; 402 403 // Verify that all of the indices are within the shape dimensions. 404 auto shape = type.getShape(); 405 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 406 int64_t dim = static_cast<int64_t>(index[i]); 407 return 0 <= dim && dim < shape[i]; 408 }); 409 } 410 411 ElementsAttr 412 ElementsAttr::mapValues(Type newElementType, 413 function_ref<APInt(const APInt &)> mapping) const { 414 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 415 return intOrFpAttr.mapValues(newElementType, mapping); 416 llvm_unreachable("unsupported ElementsAttr subtype"); 417 } 418 419 ElementsAttr 420 ElementsAttr::mapValues(Type newElementType, 421 function_ref<APInt(const APFloat &)> mapping) const { 422 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 423 return intOrFpAttr.mapValues(newElementType, mapping); 424 llvm_unreachable("unsupported ElementsAttr subtype"); 425 } 426 427 /// Method for support type inquiry through isa, cast and dyn_cast. 
428 bool ElementsAttr::classof(Attribute attr) { 429 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 430 OpaqueElementsAttr, SparseElementsAttr>(); 431 } 432 433 /// Returns the 1 dimensional flattened row-major index from the given 434 /// multi-dimensional index. 435 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 436 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 437 auto type = getType(); 438 439 // Reduce the provided multidimensional index into a flattended 1D row-major 440 // index. 441 auto rank = type.getRank(); 442 auto shape = type.getShape(); 443 uint64_t valueIndex = 0; 444 uint64_t dimMultiplier = 1; 445 for (int i = rank - 1; i >= 0; --i) { 446 valueIndex += index[i] * dimMultiplier; 447 dimMultiplier *= shape[i]; 448 } 449 return valueIndex; 450 } 451 452 //===----------------------------------------------------------------------===// 453 // DenseElementsAttr Utilities 454 //===----------------------------------------------------------------------===// 455 456 /// Get the bitwidth of a dense element type within the buffer. 457 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 458 static size_t getDenseElementStorageWidth(size_t origWidth) { 459 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 460 } 461 static size_t getDenseElementStorageWidth(Type elementType) { 462 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 463 } 464 465 /// Set a bit to a specific value. 466 static void setBit(char *rawData, size_t bitPos, bool value) { 467 if (value) 468 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 469 else 470 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 471 } 472 473 /// Return the value of the specified bit. 
474 static bool getBit(const char *rawData, size_t bitPos) { 475 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 476 } 477 478 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 479 /// BE format. 480 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 481 char *result) { 482 assert(llvm::support::endian::system_endianness() == // NOLINT 483 llvm::support::endianness::big); // NOLINT 484 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 485 486 // Copy the words filled with data. 487 // For example, when `value` has 2 words, the first word is filled with data. 488 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 489 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 490 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 491 numFilledWords, result); 492 // Convert last word of APInt to LE format and store it in char 493 // array(`valueLE`). 494 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 495 size_t lastWordPos = numFilledWords; 496 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 497 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 498 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 499 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 500 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 501 // and store it in `result`. 502 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 503 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 504 valueLE.begin(), result + lastWordPos, 505 (numBytes - lastWordPos) * CHAR_BIT, 1); 506 } 507 508 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 509 /// format. 
510 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 511 APInt &result) { 512 assert(llvm::support::endian::system_endianness() == // NOLINT 513 llvm::support::endianness::big); // NOLINT 514 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 515 516 // Copy the data that fills the word of `result` from `inArray`. 517 // For example, when `result` has 2 words, the first word will be filled with 518 // data. So, the first 8 bytes are copied from `inArray` here. 519 // `inArray` (10 bytes, BE): |abcdefgh|ij| 520 // ==> `result` (2 words, BE): |abcdefgh|--------| 521 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 522 std::copy_n( 523 inArray, numFilledWords, 524 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 525 526 // Convert array data which will be last word of `result` to LE format, and 527 // store it in char array(`inArrayLE`). 528 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 529 size_t lastWordPos = numFilledWords; 530 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 531 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 532 inArray + lastWordPos, inArrayLE.begin(), 533 (numBytes - lastWordPos) * CHAR_BIT, 1); 534 535 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 536 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 537 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 538 inArrayLE.begin(), 539 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 540 lastWordPos, 541 APInt::APINT_BITS_PER_WORD, 1); 542 } 543 544 /// Writes value to the bit position `bitPos` in array `rawData`. 545 static void writeBits(char *rawData, size_t bitPos, APInt value) { 546 size_t bitWidth = value.getBitWidth(); 547 548 // If the bitwidth is 1 we just toggle the specific bit. 
549 if (bitWidth == 1) 550 return setBit(rawData, bitPos, value.isOneValue()); 551 552 // Otherwise, the bit position is guaranteed to be byte aligned. 553 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 554 if (llvm::support::endian::system_endianness() == 555 llvm::support::endianness::big) { 556 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 557 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 558 // work correctly in BE format. 559 // ex. `value` (2 words including 10 bytes) 560 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 561 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 562 rawData + (bitPos / CHAR_BIT)); 563 } else { 564 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 565 llvm::divideCeil(bitWidth, CHAR_BIT), 566 rawData + (bitPos / CHAR_BIT)); 567 } 568 } 569 570 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 571 /// `rawData`. 572 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 573 // Handle a boolean bit position. 574 if (bitWidth == 1) 575 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 576 577 // Otherwise, the bit position must be 8-bit aligned. 578 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 579 APInt result(bitWidth, 0); 580 if (llvm::support::endian::system_endianness() == 581 llvm::support::endianness::big) { 582 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 583 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 584 // work correctly in BE format. 585 // ex. 
`result` (2 words including 10 bytes) 586 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 587 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 588 llvm::divideCeil(bitWidth, CHAR_BIT), result); 589 } else { 590 std::copy_n(rawData + (bitPos / CHAR_BIT), 591 llvm::divideCeil(bitWidth, CHAR_BIT), 592 const_cast<char *>( 593 reinterpret_cast<const char *>(result.getRawData()))); 594 } 595 return result; 596 } 597 598 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 599 /// the same element count as 'type'. 600 template <typename Values> 601 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 602 return (values.size() == 1) || 603 (type.getNumElements() == static_cast<int64_t>(values.size())); 604 } 605 606 //===----------------------------------------------------------------------===// 607 // DenseElementsAttr Iterators 608 //===----------------------------------------------------------------------===// 609 610 //===----------------------------------------------------------------------===// 611 // AttributeElementIterator 612 613 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 614 DenseElementsAttr attr, size_t index) 615 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 616 Attribute, Attribute, Attribute>( 617 attr.getAsOpaquePointer(), index) {} 618 619 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 620 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 621 Type eltTy = owner.getType().getElementType(); 622 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 623 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 624 if (eltTy.isa<IndexType>()) 625 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 626 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 627 IntElementIterator intIt(owner, index); 628 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), 
intIt); 629 return FloatAttr::get(eltTy, *floatIt); 630 } 631 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 632 auto complexEltTy = complexTy.getElementType(); 633 ComplexIntElementIterator complexIntIt(owner, index); 634 if (complexEltTy.isa<IntegerType>()) { 635 auto value = *complexIntIt; 636 auto real = IntegerAttr::get(complexEltTy, value.real()); 637 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 638 return ArrayAttr::get(complexTy.getContext(), 639 ArrayRef<Attribute>{real, imag}); 640 } 641 642 ComplexFloatElementIterator complexFloatIt( 643 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 644 auto value = *complexFloatIt; 645 auto real = FloatAttr::get(complexEltTy, value.real()); 646 auto imag = FloatAttr::get(complexEltTy, value.imag()); 647 return ArrayAttr::get(complexTy.getContext(), 648 ArrayRef<Attribute>{real, imag}); 649 } 650 if (owner.isa<DenseStringElementsAttr>()) { 651 ArrayRef<StringRef> vals = owner.getRawStringData(); 652 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 653 } 654 llvm_unreachable("unexpected element type"); 655 } 656 657 //===----------------------------------------------------------------------===// 658 // BoolElementIterator 659 660 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 661 DenseElementsAttr attr, size_t dataIndex) 662 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 663 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 664 665 bool DenseElementsAttr::BoolElementIterator::operator*() const { 666 return getBit(getData(), getDataIndex()); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // IntElementIterator 671 672 DenseElementsAttr::IntElementIterator::IntElementIterator( 673 DenseElementsAttr attr, size_t dataIndex) 674 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 675 attr.getRawData().data(), attr.isSplat(), dataIndex), 676 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 677 678 APInt DenseElementsAttr::IntElementIterator::operator*() const { 679 return readBits(getData(), 680 getDataIndex() * getDenseElementStorageWidth(bitWidth), 681 bitWidth); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // ComplexIntElementIterator 686 687 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 688 DenseElementsAttr attr, size_t dataIndex) 689 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 690 std::complex<APInt>, std::complex<APInt>, 691 std::complex<APInt>>( 692 attr.getRawData().data(), attr.isSplat(), dataIndex) { 693 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 694 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 695 } 696 697 std::complex<APInt> 698 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 699 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 700 size_t offset = 
getDataIndex() * storageWidth * 2; 701 return {readBits(getData(), offset, bitWidth), 702 readBits(getData(), offset + storageWidth, bitWidth)}; 703 } 704 705 //===----------------------------------------------------------------------===// 706 // FloatElementIterator 707 708 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 709 const llvm::fltSemantics &smt, IntElementIterator it) 710 : llvm::mapped_iterator<IntElementIterator, 711 std::function<APFloat(const APInt &)>>( 712 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 713 714 //===----------------------------------------------------------------------===// 715 // ComplexFloatElementIterator 716 717 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 718 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 719 : llvm::mapped_iterator< 720 ComplexIntElementIterator, 721 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 722 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 723 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 724 }) {} 725 726 //===----------------------------------------------------------------------===// 727 // DenseElementsAttr 728 //===----------------------------------------------------------------------===// 729 730 /// Method for support type inquiry through isa, cast and dyn_cast. 731 bool DenseElementsAttr::classof(Attribute attr) { 732 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 733 } 734 735 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 736 ArrayRef<Attribute> values) { 737 assert(hasSameElementsOrSplat(type, values)); 738 739 // If the element type is not based on int/float/index, assume it is a string 740 // type. 
741 auto eltType = type.getElementType(); 742 if (!type.getElementType().isIntOrIndexOrFloat()) { 743 SmallVector<StringRef, 8> stringValues; 744 stringValues.reserve(values.size()); 745 for (Attribute attr : values) { 746 assert(attr.isa<StringAttr>() && 747 "expected string value for non integer/index/float element"); 748 stringValues.push_back(attr.cast<StringAttr>().getValue()); 749 } 750 return get(type, stringValues); 751 } 752 753 // Otherwise, get the raw storage width to use for the allocation. 754 size_t bitWidth = getDenseElementBitWidth(eltType); 755 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 756 757 // Compress the attribute values into a character buffer. 758 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 759 values.size()); 760 APInt intVal; 761 for (unsigned i = 0, e = values.size(); i < e; ++i) { 762 assert(eltType == values[i].getType() && 763 "expected attribute value to have element type"); 764 if (eltType.isa<FloatType>()) 765 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 766 else if (eltType.isa<IntegerType, IndexType>()) 767 intVal = values[i].cast<IntegerAttr>().getValue(); 768 else 769 llvm_unreachable("unexpected element type"); 770 771 assert(intVal.getBitWidth() == bitWidth && 772 "expected value to have same bitwidth as element type"); 773 writeBits(data.data(), i * storageBitWidth, intVal); 774 } 775 return DenseIntOrFPElementsAttr::getRaw(type, data, 776 /*isSplat=*/(values.size() == 1)); 777 } 778 779 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 780 ArrayRef<bool> values) { 781 assert(hasSameElementsOrSplat(type, values)); 782 assert(type.getElementType().isInteger(1)); 783 784 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 785 for (int i = 0, e = values.size(); i != e; ++i) 786 setBit(buff.data(), i, values[i]); 787 return DenseIntOrFPElementsAttr::getRaw(type, buff, 788 /*isSplat=*/(values.size() == 1)); 789 } 790 791 
/// Build a dense string elements attribute from an array of StringRef values.
/// The element type of 'type' must be neither an integer nor a float type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  assert(!type.getElementType().isIntOrFloat());
  return DenseStringElementsAttr::get(type, values);
}

/// Build a dense integer elements attribute from an array of APInt values.
/// The element type of 'type' must be an integer or index type, and each
/// APInt is expected to have the bitwidth of that element type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type elementTy = type.getElementType();
  assert(elementTy.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));

  const bool isSplat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(
      type, getDenseElementStorageWidth(elementTy), values, isSplat);
}

/// Build a dense complex-integer elements attribute. The complex values are
/// viewed as a flat APInt array with twice as many entries (real part followed
/// by imaginary part), each stored in half of the complex storage width.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  ComplexType complexTy = type.getElementType().cast<ComplexType>();
  assert(complexTy.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));

  const size_t partStorageWidth = getDenseElementStorageWidth(complexTy) / 2;
  const auto *firstPart = reinterpret_cast<const APInt *>(values.data());
  ArrayRef<APInt> flatParts(firstPart, values.size() * 2);

  const bool isSplat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(type, partStorageWidth, flatParts,
                                          isSplat);
}

/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'.
823 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 824 ArrayRef<APFloat> values) { 825 assert(type.getElementType().isa<FloatType>()); 826 assert(hasSameElementsOrSplat(type, values)); 827 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 828 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 829 /*isSplat=*/(values.size() == 1)); 830 } 831 DenseElementsAttr 832 DenseElementsAttr::get(ShapedType type, 833 ArrayRef<std::complex<APFloat>> values) { 834 ComplexType complex = type.getElementType().cast<ComplexType>(); 835 assert(complex.getElementType().isa<FloatType>()); 836 assert(hasSameElementsOrSplat(type, values)); 837 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 838 values.size() * 2); 839 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 840 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 841 /*isSplat=*/(values.size() == 1)); 842 } 843 844 /// Construct a dense elements attribute from a raw buffer representing the 845 /// data for this attribute. Users should generally not use this methods as 846 /// the expected buffer format may not be a form the user expects. 847 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 848 ArrayRef<char> rawBuffer, 849 bool isSplatBuffer) { 850 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 851 } 852 853 /// Returns true if the given buffer is a valid raw buffer for the given type. 854 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 855 ArrayRef<char> rawBuffer, 856 bool &detectedSplat) { 857 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 858 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 859 860 // Storage width of 1 is special as it is packed by the bit. 861 if (storageWidth == 1) { 862 // Check for a splat, or a buffer equal to the number of elements. 
863 if ((detectedSplat = rawBuffer.size() == 1)) 864 return true; 865 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 866 } 867 // All other types are 8-bit aligned. 868 if ((detectedSplat = rawBufferWidth == storageWidth)) 869 return true; 870 return rawBufferWidth == (storageWidth * type.getNumElements()); 871 } 872 873 /// Check the information for a C++ data type, check if this type is valid for 874 /// the current attribute. This method is used to verify specific type 875 /// invariants that the templatized 'getValues' method cannot. 876 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 877 bool isSigned) { 878 // Make sure that the data element size is the same as the type element width. 879 if (getDenseElementBitWidth(type) != 880 static_cast<size_t>(dataEltSize * CHAR_BIT)) 881 return false; 882 883 // Check that the element type is either float or integer or index. 884 if (!isInt) 885 return type.isa<FloatType>(); 886 if (type.isIndex()) 887 return true; 888 889 auto intType = type.dyn_cast<IntegerType>(); 890 if (!intType) 891 return false; 892 893 // Make sure signedness semantics is consistent. 894 if (intType.isSignless()) 895 return true; 896 return intType.isSigned() ? isSigned : !isSigned; 897 } 898 899 /// Defaults down the subclass implementation. 900 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 901 ArrayRef<char> data, 902 int64_t dataEltSize, 903 bool isInt, bool isSigned) { 904 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 905 isSigned); 906 } 907 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 908 ArrayRef<char> data, 909 int64_t dataEltSize, 910 bool isInt, 911 bool isSigned) { 912 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 913 isInt, isSigned); 914 } 915 916 /// A method used to verify specific type invariants that the templatized 'get' 917 /// method cannot. 
918 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 919 bool isSigned) const { 920 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 921 isSigned); 922 } 923 924 /// Check the information for a C++ data type, check if this type is valid for 925 /// the current attribute. 926 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 927 bool isSigned) const { 928 return ::isValidIntOrFloat( 929 getType().getElementType().cast<ComplexType>().getElementType(), 930 dataEltSize / 2, isInt, isSigned); 931 } 932 933 /// Returns true if this attribute corresponds to a splat, i.e. if all element 934 /// values are the same. 935 bool DenseElementsAttr::isSplat() const { 936 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 937 } 938 939 /// Return the held element values as a range of Attributes. 940 auto DenseElementsAttr::getAttributeValues() const 941 -> llvm::iterator_range<AttributeElementIterator> { 942 return {attr_value_begin(), attr_value_end()}; 943 } 944 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 945 return AttributeElementIterator(*this, 0); 946 } 947 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 948 return AttributeElementIterator(*this, getNumElements()); 949 } 950 951 /// Return the held element values as a range of bool. The element type of 952 /// this attribute must be of integer type of bitwidth 1. 953 auto DenseElementsAttr::getBoolValues() const 954 -> llvm::iterator_range<BoolElementIterator> { 955 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 956 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 957 (void)eltType; 958 return {BoolElementIterator(*this, 0), 959 BoolElementIterator(*this, getNumElements())}; 960 } 961 962 /// Return the held element values as a range of APInts. The element type of 963 /// this attribute must be of integer type. 
964 auto DenseElementsAttr::getIntValues() const 965 -> llvm::iterator_range<IntElementIterator> { 966 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 967 return {raw_int_begin(), raw_int_end()}; 968 } 969 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 970 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 971 return raw_int_begin(); 972 } 973 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 974 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 975 return raw_int_end(); 976 } 977 auto DenseElementsAttr::getComplexIntValues() const 978 -> llvm::iterator_range<ComplexIntElementIterator> { 979 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 980 (void)eltTy; 981 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 982 return {ComplexIntElementIterator(*this, 0), 983 ComplexIntElementIterator(*this, getNumElements())}; 984 } 985 986 /// Return the held element values as a range of APFloat. The element type of 987 /// this attribute must be of float type. 
988 auto DenseElementsAttr::getFloatValues() const 989 -> llvm::iterator_range<FloatElementIterator> { 990 auto elementType = getType().getElementType().cast<FloatType>(); 991 const auto &elementSemantics = elementType.getFloatSemantics(); 992 return {FloatElementIterator(elementSemantics, raw_int_begin()), 993 FloatElementIterator(elementSemantics, raw_int_end())}; 994 } 995 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 996 return getFloatValues().begin(); 997 } 998 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 999 return getFloatValues().end(); 1000 } 1001 auto DenseElementsAttr::getComplexFloatValues() const 1002 -> llvm::iterator_range<ComplexFloatElementIterator> { 1003 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1004 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1005 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1006 return {{semantics, {*this, 0}}, 1007 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1008 } 1009 1010 /// Return the raw storage data held by this attribute. 1011 ArrayRef<char> DenseElementsAttr::getRawData() const { 1012 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1013 } 1014 1015 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1016 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1017 } 1018 1019 /// Return a new DenseElementsAttr that has the same data as the current 1020 /// attribute, but has been reshaped to 'newType'. The new type must have the 1021 /// same total number of elements as well as element type. 
1022 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1023 ShapedType curType = getType(); 1024 if (curType == newType) 1025 return *this; 1026 1027 assert(newType.getElementType() == curType.getElementType() && 1028 "expected the same element type"); 1029 assert(newType.getNumElements() == curType.getNumElements() && 1030 "expected the same number of elements"); 1031 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1032 } 1033 1034 /// Return a new DenseElementsAttr that has the same data as the current 1035 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1036 /// type must have the same shape and element types of the same bitwidth as the 1037 /// current type. 1038 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1039 ShapedType curType = getType(); 1040 Type curElType = curType.getElementType(); 1041 if (curElType == newElType) 1042 return *this; 1043 1044 assert(getDenseElementBitWidth(newElType) == 1045 getDenseElementBitWidth(curElType) && 1046 "expected element types with the same bitwidth"); 1047 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1048 getRawData(), isSplat()); 1049 } 1050 1051 DenseElementsAttr 1052 DenseElementsAttr::mapValues(Type newElementType, 1053 function_ref<APInt(const APInt &)> mapping) const { 1054 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1055 } 1056 1057 DenseElementsAttr DenseElementsAttr::mapValues( 1058 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1059 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1060 } 1061 1062 //===----------------------------------------------------------------------===// 1063 // DenseIntOrFPElementsAttr 1064 //===----------------------------------------------------------------------===// 1065 1066 /// Utility method to write a range of APInt values to a buffer. 
1067 template <typename APRangeT> 1068 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1069 APRangeT &&values) { 1070 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1071 size_t offset = 0; 1072 for (auto it = values.begin(), e = values.end(); it != e; 1073 ++it, offset += storageWidth) { 1074 assert((*it).getBitWidth() <= storageWidth); 1075 writeBits(data.data(), offset, *it); 1076 } 1077 } 1078 1079 /// Constructs a dense elements attribute from an array of raw APFloat values. 1080 /// Each APFloat value is expected to have the same bitwidth as the element 1081 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1082 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1083 size_t storageWidth, 1084 ArrayRef<APFloat> values, 1085 bool isSplat) { 1086 std::vector<char> data; 1087 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1088 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1089 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1090 } 1091 1092 /// Constructs a dense elements attribute from an array of raw APInt values. 1093 /// Each APInt value is expected to have the same bitwidth as the element type 1094 /// of 'type'. 
1095 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1096 size_t storageWidth, 1097 ArrayRef<APInt> values, 1098 bool isSplat) { 1099 std::vector<char> data; 1100 writeAPIntsToBuffer(storageWidth, data, values); 1101 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1102 } 1103 1104 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1105 ArrayRef<char> data, 1106 bool isSplat) { 1107 assert((type.isa<RankedTensorType, VectorType>()) && 1108 "type must be ranked tensor or vector"); 1109 assert(type.hasStaticShape() && "type must have static shape"); 1110 return Base::get(type.getContext(), type, data, isSplat); 1111 } 1112 1113 /// Overload of the raw 'get' method that asserts that the given type is of 1114 /// complex type. This method is used to verify type invariants that the 1115 /// templatized 'get' method cannot. 1116 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1117 ArrayRef<char> data, 1118 int64_t dataEltSize, 1119 bool isInt, 1120 bool isSigned) { 1121 assert(::isValidIntOrFloat( 1122 type.getElementType().cast<ComplexType>().getElementType(), 1123 dataEltSize / 2, isInt, isSigned)); 1124 1125 int64_t numElements = data.size() / dataEltSize; 1126 assert(numElements == 1 || numElements == type.getNumElements()); 1127 return getRaw(type, data, /*isSplat=*/numElements == 1); 1128 } 1129 1130 /// Overload of the 'getRaw' method that asserts that the given type is of 1131 /// integer type. This method is used to verify type invariants that the 1132 /// templatized 'get' method cannot. 
1133 DenseElementsAttr 1134 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1135 int64_t dataEltSize, bool isInt, 1136 bool isSigned) { 1137 assert( 1138 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1139 1140 int64_t numElements = data.size() / dataEltSize; 1141 assert(numElements == 1 || numElements == type.getNumElements()); 1142 return getRaw(type, data, /*isSplat=*/numElements == 1); 1143 } 1144 1145 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1146 const char *inRawData, char *outRawData, size_t elementBitWidth, 1147 size_t numElements) { 1148 using llvm::support::ulittle16_t; 1149 using llvm::support::ulittle32_t; 1150 using llvm::support::ulittle64_t; 1151 1152 assert(llvm::support::endian::system_endianness() == // NOLINT 1153 llvm::support::endianness::big); // NOLINT 1154 // NOLINT to avoid warning message about replacing by static_assert() 1155 1156 // Following std::copy_n always converts endianness on BE machine. 
1157 switch (elementBitWidth) { 1158 case 16: { 1159 const ulittle16_t *inRawDataPos = 1160 reinterpret_cast<const ulittle16_t *>(inRawData); 1161 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1162 std::copy_n(inRawDataPos, numElements, outDataPos); 1163 break; 1164 } 1165 case 32: { 1166 const ulittle32_t *inRawDataPos = 1167 reinterpret_cast<const ulittle32_t *>(inRawData); 1168 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1169 std::copy_n(inRawDataPos, numElements, outDataPos); 1170 break; 1171 } 1172 case 64: { 1173 const ulittle64_t *inRawDataPos = 1174 reinterpret_cast<const ulittle64_t *>(inRawData); 1175 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1176 std::copy_n(inRawDataPos, numElements, outDataPos); 1177 break; 1178 } 1179 default: { 1180 size_t nBytes = elementBitWidth / CHAR_BIT; 1181 for (size_t i = 0; i < nBytes; i++) 1182 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1183 break; 1184 } 1185 } 1186 } 1187 1188 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1189 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1190 ShapedType type) { 1191 size_t numElements = type.getNumElements(); 1192 Type elementType = type.getElementType(); 1193 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1194 elementType = complexTy.getElementType(); 1195 numElements = numElements * 2; 1196 } 1197 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1198 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1199 inRawData.size() <= outRawData.size()); 1200 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1201 elementBitWidth, numElements); 1202 } 1203 1204 //===----------------------------------------------------------------------===// 1205 // DenseFPElementsAttr 1206 //===----------------------------------------------------------------------===// 1207 1208 template <typename Fn, typename Attr> 
1209 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1210 Type newElementType, 1211 llvm::SmallVectorImpl<char> &data) { 1212 size_t bitWidth = getDenseElementBitWidth(newElementType); 1213 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1214 1215 ShapedType newArrayType; 1216 if (inType.isa<RankedTensorType>()) 1217 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1218 else if (inType.isa<UnrankedTensorType>()) 1219 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1220 else if (inType.isa<VectorType>()) 1221 newArrayType = VectorType::get(inType.getShape(), newElementType); 1222 else 1223 assert(newArrayType && "Unhandled tensor type"); 1224 1225 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1226 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1227 1228 // Functor used to process a single element value of the attribute. 1229 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1230 auto newInt = mapping(value); 1231 assert(newInt.getBitWidth() == bitWidth); 1232 writeBits(data.data(), index * storageBitWidth, newInt); 1233 }; 1234 1235 // Check for the splat case. 1236 if (attr.isSplat()) { 1237 processElt(*attr.begin(), /*index=*/0); 1238 return newArrayType; 1239 } 1240 1241 // Otherwise, process all of the element values. 1242 uint64_t elementIdx = 0; 1243 for (auto value : attr) 1244 processElt(value, elementIdx++); 1245 return newArrayType; 1246 } 1247 1248 DenseElementsAttr DenseFPElementsAttr::mapValues( 1249 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1250 llvm::SmallVector<char, 8> elementData; 1251 auto newArrayType = 1252 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1253 1254 return getRaw(newArrayType, elementData, isSplat()); 1255 } 1256 1257 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1258 bool DenseFPElementsAttr::classof(Attribute attr) { 1259 return attr.isa<DenseElementsAttr>() && 1260 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1261 } 1262 1263 //===----------------------------------------------------------------------===// 1264 // DenseIntElementsAttr 1265 //===----------------------------------------------------------------------===// 1266 1267 DenseElementsAttr DenseIntElementsAttr::mapValues( 1268 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1269 llvm::SmallVector<char, 8> elementData; 1270 auto newArrayType = 1271 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1272 1273 return getRaw(newArrayType, elementData, isSplat()); 1274 } 1275 1276 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1277 bool DenseIntElementsAttr::classof(Attribute attr) { 1278 return attr.isa<DenseElementsAttr>() && 1279 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1280 } 1281 1282 //===----------------------------------------------------------------------===// 1283 // OpaqueElementsAttr 1284 //===----------------------------------------------------------------------===// 1285 1286 /// Return the value at the given index. If index does not refer to a valid 1287 /// element, then a null attribute is returned. 
1288 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1289 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1290 return Attribute(); 1291 } 1292 1293 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1294 Dialect *dialect = getDialect().getDialect(); 1295 if (!dialect) 1296 return true; 1297 auto *interface = 1298 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1299 if (!interface) 1300 return true; 1301 return failed(interface->decode(*this, result)); 1302 } 1303 1304 LogicalResult 1305 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1306 Identifier dialect, StringRef value, 1307 ShapedType type) { 1308 if (!Dialect::isValidNamespace(dialect.strref())) 1309 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1310 return success(); 1311 } 1312 1313 //===----------------------------------------------------------------------===// 1314 // SparseElementsAttr 1315 //===----------------------------------------------------------------------===// 1316 1317 /// Return the value of the element at the given index. 1318 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1319 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1320 auto type = getType(); 1321 1322 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1323 // as a 1-D index array. 1324 auto sparseIndices = getIndices(); 1325 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1326 1327 // Check to see if the indices are a splat. 1328 if (sparseIndices.isSplat()) { 1329 // If the index is also not a splat of the index value, we know that the 1330 // value is zero. 1331 auto splatIndex = *sparseIndexValues.begin(); 1332 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1333 return getZeroAttr(); 1334 1335 // If the indices are a splat, we also expect the values to be a splat. 
1336 assert(getValues().isSplat() && "expected splat values"); 1337 return getValues().getSplatValue(); 1338 } 1339 1340 // Build a mapping between known indices and the offset of the stored element. 1341 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1342 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1343 size_t rank = type.getRank(); 1344 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1345 mappedIndices.try_emplace( 1346 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1347 1348 // Look for the provided index key within the mapped indices. If the provided 1349 // index is not found, then return a zero attribute. 1350 auto it = mappedIndices.find(index); 1351 if (it == mappedIndices.end()) 1352 return getZeroAttr(); 1353 1354 // Otherwise, return the held sparse value element. 1355 return getValues().getValue(it->second); 1356 } 1357 1358 /// Get a zero APFloat for the given sparse attribute. 1359 APFloat SparseElementsAttr::getZeroAPFloat() const { 1360 auto eltType = getType().getElementType().cast<FloatType>(); 1361 return APFloat(eltType.getFloatSemantics()); 1362 } 1363 1364 /// Get a zero APInt for the given sparse attribute. 1365 APInt SparseElementsAttr::getZeroAPInt() const { 1366 auto eltType = getType().getElementType().cast<IntegerType>(); 1367 return APInt::getNullValue(eltType.getWidth()); 1368 } 1369 1370 /// Get a zero attribute for the given attribute type. 1371 Attribute SparseElementsAttr::getZeroAttr() const { 1372 auto eltType = getType().getElementType(); 1373 1374 // Handle floating point elements. 1375 if (eltType.isa<FloatType>()) 1376 return FloatAttr::get(eltType, 0); 1377 1378 // Otherwise, this is an integer. 1379 // TODO: Handle StringAttr here. 1380 return IntegerAttr::get(eltType, 0); 1381 } 1382 1383 /// Flatten, and return, all of the sparse indices in this attribute in 1384 /// row-major order. 
1385 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1386 std::vector<ptrdiff_t> flatSparseIndices; 1387 1388 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1389 // as a 1-D index array. 1390 auto sparseIndices = getIndices(); 1391 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1392 if (sparseIndices.isSplat()) { 1393 SmallVector<uint64_t, 8> indices(getType().getRank(), 1394 *sparseIndexValues.begin()); 1395 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1396 return flatSparseIndices; 1397 } 1398 1399 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1400 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1401 size_t rank = getType().getRank(); 1402 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1403 flatSparseIndices.push_back(getFlattenedIndex( 1404 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1405 return flatSparseIndices; 1406 } 1407 1408 //===----------------------------------------------------------------------===// 1409 // TypeAttr 1410 //===----------------------------------------------------------------------===// 1411 1412 void TypeAttr::walkImmediateSubElements( 1413 function_ref<void(Attribute)> walkAttrsFn, 1414 function_ref<void(Type)> walkTypesFn) const { 1415 walkTypesFn(getValue()); 1416 } 1417