//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Attribute Definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Registers the builtin attribute classes listed below with the builtin
/// dialect.
void BuiltinDialect::registerAttributes() {
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
}

//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
If inPlace then storage is both the source and the 50 /// destination, else value is the source and storage destination. Returns 51 /// whether source was sorted. 52 template <bool inPlace> 53 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 54 SmallVectorImpl<NamedAttribute> &storage) { 55 // Specialize for the common case. 56 switch (value.size()) { 57 case 0: 58 // Zero already sorted. 59 break; 60 case 1: 61 // One already sorted but may need to be copied. 62 if (!inPlace) 63 storage.assign({value[0]}); 64 break; 65 case 2: { 66 bool isSorted = value[0] < value[1]; 67 if (inPlace) { 68 if (!isSorted) 69 std::swap(storage[0], storage[1]); 70 } else if (isSorted) { 71 storage.assign({value[0], value[1]}); 72 } else { 73 storage.assign({value[1], value[0]}); 74 } 75 return !isSorted; 76 } 77 default: 78 if (!inPlace) 79 storage.assign(value.begin(), value.end()); 80 // Check to see they are sorted already. 81 bool isSorted = llvm::is_sorted(value); 82 // If not, do a general sort. 83 if (!isSorted) 84 llvm::array_pod_sort(storage.begin(), storage.end()); 85 return !isSorted; 86 } 87 return false; 88 } 89 90 /// Returns an entry with a duplicate name from the given sorted array of named 91 /// attributes. Returns llvm::None if all elements have unique names. 92 static Optional<NamedAttribute> 93 findDuplicateElement(ArrayRef<NamedAttribute> value) { 94 const Optional<NamedAttribute> none{llvm::None}; 95 if (value.size() < 2) 96 return none; 97 98 if (value.size() == 2) 99 return value[0].first == value[1].first ? value[0] : none; 100 101 auto it = std::adjacent_find( 102 value.begin(), value.end(), 103 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 104 return it != value.end() ? 
*it : none; 105 } 106 107 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 108 SmallVectorImpl<NamedAttribute> &storage) { 109 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 110 assert(!findDuplicateElement(storage) && 111 "DictionaryAttr element names must be unique"); 112 return isSorted; 113 } 114 115 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 116 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 117 assert(!findDuplicateElement(array) && 118 "DictionaryAttr element names must be unique"); 119 return isSorted; 120 } 121 122 Optional<NamedAttribute> 123 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 124 bool isSorted) { 125 if (!isSorted) 126 dictionaryAttrSort</*inPlace=*/true>(array, array); 127 return findDuplicateElement(array); 128 } 129 130 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 131 ArrayRef<NamedAttribute> value) { 132 if (value.empty()) 133 return DictionaryAttr::getEmpty(context); 134 assert(llvm::all_of(value, 135 [](const NamedAttribute &attr) { return attr.second; }) && 136 "value cannot have null entries"); 137 138 // We need to sort the element list to canonicalize it. 139 SmallVector<NamedAttribute, 8> storage; 140 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 141 value = storage; 142 assert(!findDuplicateElement(value) && 143 "DictionaryAttr element names must be unique"); 144 return Base::get(context, value); 145 } 146 /// Construct a dictionary with an array of values that is known to already be 147 /// sorted by name and uniqued. 148 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 149 ArrayRef<NamedAttribute> value) { 150 if (value.empty()) 151 return DictionaryAttr::getEmpty(context); 152 // Ensure that the attribute elements are unique and sorted. 
153 assert(llvm::is_sorted(value, 154 [](NamedAttribute l, NamedAttribute r) { 155 return l.first.strref() < r.first.strref(); 156 }) && 157 "expected attribute values to be sorted"); 158 assert(!findDuplicateElement(value) && 159 "DictionaryAttr element names must be unique"); 160 return Base::get(context, value); 161 } 162 163 /// Return the specified attribute if present, null otherwise. 164 Attribute DictionaryAttr::get(StringRef name) const { 165 Optional<NamedAttribute> attr = getNamed(name); 166 return attr ? attr->second : nullptr; 167 } 168 Attribute DictionaryAttr::get(Identifier name) const { 169 Optional<NamedAttribute> attr = getNamed(name); 170 return attr ? attr->second : nullptr; 171 } 172 173 /// Return the specified named attribute if present, None otherwise. 174 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 175 ArrayRef<NamedAttribute> values = getValue(); 176 const auto *it = llvm::lower_bound(values, name); 177 return it != values.end() && it->first == name ? 
*it 178 : Optional<NamedAttribute>(); 179 } 180 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 181 for (auto elt : getValue()) 182 if (elt.first == name) 183 return elt; 184 return llvm::None; 185 } 186 187 DictionaryAttr::iterator DictionaryAttr::begin() const { 188 return getValue().begin(); 189 } 190 DictionaryAttr::iterator DictionaryAttr::end() const { 191 return getValue().end(); 192 } 193 size_t DictionaryAttr::size() const { return getValue().size(); } 194 195 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 196 return Base::get(context, ArrayRef<NamedAttribute>()); 197 } 198 199 //===----------------------------------------------------------------------===// 200 // FloatAttr 201 //===----------------------------------------------------------------------===// 202 203 double FloatAttr::getValueAsDouble() const { 204 return getValueAsDouble(getValue()); 205 } 206 double FloatAttr::getValueAsDouble(APFloat value) { 207 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 208 bool losesInfo = false; 209 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 210 &losesInfo); 211 } 212 return value.convertToDouble(); 213 } 214 215 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 216 Type type, APFloat value) { 217 // Verify that the type is correct. 218 if (!type.isa<FloatType>()) 219 return emitError() << "expected floating point type"; 220 221 // Verify that the type semantics match that of the value. 
222 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 223 return emitError() 224 << "FloatAttr type doesn't match the type implied by its value"; 225 } 226 return success(); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // SymbolRefAttr 231 //===----------------------------------------------------------------------===// 232 233 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 234 return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 235 } 236 237 StringRef SymbolRefAttr::getLeafReference() const { 238 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 239 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // IntegerAttr 244 //===----------------------------------------------------------------------===// 245 246 int64_t IntegerAttr::getInt() const { 247 assert((getType().isIndex() || getType().isSignlessInteger()) && 248 "must be signless integer"); 249 return getValue().getSExtValue(); 250 } 251 252 int64_t IntegerAttr::getSInt() const { 253 assert(getType().isSignedInteger() && "must be signed integer"); 254 return getValue().getSExtValue(); 255 } 256 257 uint64_t IntegerAttr::getUInt() const { 258 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 259 return getValue().getZExtValue(); 260 } 261 262 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 263 Type type, APInt value) { 264 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 265 if (integerType.getWidth() != value.getBitWidth()) 266 return emitError() << "integer type bit width (" << integerType.getWidth() 267 << ") doesn't match value bit width (" 268 << value.getBitWidth() << ")"; 269 return success(); 270 } 271 if (type.isa<IndexType>()) 272 return success(); 273 return emitError() << "expected 
integer or index type"; 274 } 275 276 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 277 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 278 return attr.cast<BoolAttr>(); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // BoolAttr 283 284 bool BoolAttr::getValue() const { 285 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 286 return storage->value.getBoolValue(); 287 } 288 289 bool BoolAttr::classof(Attribute attr) { 290 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 291 return intAttr && intAttr.getType().isSignlessInteger(1); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // OpaqueAttr 296 //===----------------------------------------------------------------------===// 297 298 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 299 Identifier dialect, StringRef attrData, 300 Type type) { 301 if (!Dialect::isValidNamespace(dialect.strref())) 302 return emitError() << "invalid dialect namespace '" << dialect << "'"; 303 304 // Check that the dialect is actually registered. 305 MLIRContext *context = dialect.getContext(); 306 if (!context->allowsUnregisteredDialects() && 307 !context->getLoadedDialect(dialect.strref())) { 308 return emitError() 309 << "#" << dialect << "<\"" << attrData << "\"> : " << type 310 << " attribute created with unregistered dialect. 
If this is " 311 "intended, please call allowUnregisteredDialects() on the " 312 "MLIRContext, or use -allow-unregistered-dialect with " 313 "mlir-opt"; 314 } 315 316 return success(); 317 } 318 319 //===----------------------------------------------------------------------===// 320 // ElementsAttr 321 //===----------------------------------------------------------------------===// 322 323 ShapedType ElementsAttr::getType() const { 324 return Attribute::getType().cast<ShapedType>(); 325 } 326 327 /// Returns the number of elements held by this attribute. 328 int64_t ElementsAttr::getNumElements() const { 329 return getType().getNumElements(); 330 } 331 332 /// Return the value at the given index. If index does not refer to a valid 333 /// element, then a null attribute is returned. 334 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 335 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 336 return denseAttr.getValue(index); 337 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 338 return opaqueAttr.getValue(index); 339 return cast<SparseElementsAttr>().getValue(index); 340 } 341 342 /// Return if the given 'index' refers to a valid element in this attribute. 343 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 344 auto type = getType(); 345 346 // Verify that the rank of the indices matches the held type. 347 auto rank = type.getRank(); 348 if (rank == 0 && index.size() == 1 && index[0] == 0) 349 return true; 350 if (rank != static_cast<int64_t>(index.size())) 351 return false; 352 353 // Verify that all of the indices are within the shape dimensions. 
354 auto shape = type.getShape(); 355 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 356 int64_t dim = static_cast<int64_t>(index[i]); 357 return 0 <= dim && dim < shape[i]; 358 }); 359 } 360 361 ElementsAttr 362 ElementsAttr::mapValues(Type newElementType, 363 function_ref<APInt(const APInt &)> mapping) const { 364 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 365 return intOrFpAttr.mapValues(newElementType, mapping); 366 llvm_unreachable("unsupported ElementsAttr subtype"); 367 } 368 369 ElementsAttr 370 ElementsAttr::mapValues(Type newElementType, 371 function_ref<APInt(const APFloat &)> mapping) const { 372 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 373 return intOrFpAttr.mapValues(newElementType, mapping); 374 llvm_unreachable("unsupported ElementsAttr subtype"); 375 } 376 377 /// Method for support type inquiry through isa, cast and dyn_cast. 378 bool ElementsAttr::classof(Attribute attr) { 379 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 380 OpaqueElementsAttr, SparseElementsAttr>(); 381 } 382 383 /// Returns the 1 dimensional flattened row-major index from the given 384 /// multi-dimensional index. 385 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 386 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 387 auto type = getType(); 388 389 // Reduce the provided multidimensional index into a flattended 1D row-major 390 // index. 
391 auto rank = type.getRank(); 392 auto shape = type.getShape(); 393 uint64_t valueIndex = 0; 394 uint64_t dimMultiplier = 1; 395 for (int i = rank - 1; i >= 0; --i) { 396 valueIndex += index[i] * dimMultiplier; 397 dimMultiplier *= shape[i]; 398 } 399 return valueIndex; 400 } 401 402 //===----------------------------------------------------------------------===// 403 // DenseElementsAttr Utilities 404 //===----------------------------------------------------------------------===// 405 406 /// Get the bitwidth of a dense element type within the buffer. 407 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 408 static size_t getDenseElementStorageWidth(size_t origWidth) { 409 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 410 } 411 static size_t getDenseElementStorageWidth(Type elementType) { 412 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 413 } 414 415 /// Set a bit to a specific value. 416 static void setBit(char *rawData, size_t bitPos, bool value) { 417 if (value) 418 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 419 else 420 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 421 } 422 423 /// Return the value of the specified bit. 424 static bool getBit(const char *rawData, size_t bitPos) { 425 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 426 } 427 428 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 429 /// BE format. 430 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 431 char *result) { 432 assert(llvm::support::endian::system_endianness() == // NOLINT 433 llvm::support::endianness::big); // NOLINT 434 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 435 436 // Copy the words filled with data. 437 // For example, when `value` has 2 words, the first word is filled with data. 
438 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 439 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 440 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 441 numFilledWords, result); 442 // Convert last word of APInt to LE format and store it in char 443 // array(`valueLE`). 444 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 445 size_t lastWordPos = numFilledWords; 446 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 447 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 448 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 449 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 450 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 451 // and store it in `result`. 452 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 453 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 454 valueLE.begin(), result + lastWordPos, 455 (numBytes - lastWordPos) * CHAR_BIT, 1); 456 } 457 458 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 459 /// format. 460 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 461 APInt &result) { 462 assert(llvm::support::endian::system_endianness() == // NOLINT 463 llvm::support::endianness::big); // NOLINT 464 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 465 466 // Copy the data that fills the word of `result` from `inArray`. 467 // For example, when `result` has 2 words, the first word will be filled with 468 // data. So, the first 8 bytes are copied from `inArray` here. 
469 // `inArray` (10 bytes, BE): |abcdefgh|ij| 470 // ==> `result` (2 words, BE): |abcdefgh|--------| 471 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 472 std::copy_n( 473 inArray, numFilledWords, 474 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 475 476 // Convert array data which will be last word of `result` to LE format, and 477 // store it in char array(`inArrayLE`). 478 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 479 size_t lastWordPos = numFilledWords; 480 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 481 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 482 inArray + lastWordPos, inArrayLE.begin(), 483 (numBytes - lastWordPos) * CHAR_BIT, 1); 484 485 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 486 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 487 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 488 inArrayLE.begin(), 489 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 490 lastWordPos, 491 APInt::APINT_BITS_PER_WORD, 1); 492 } 493 494 /// Writes value to the bit position `bitPos` in array `rawData`. 495 static void writeBits(char *rawData, size_t bitPos, APInt value) { 496 size_t bitWidth = value.getBitWidth(); 497 498 // If the bitwidth is 1 we just toggle the specific bit. 499 if (bitWidth == 1) 500 return setBit(rawData, bitPos, value.isOneValue()); 501 502 // Otherwise, the bit position is guaranteed to be byte aligned. 503 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 504 if (llvm::support::endian::system_endianness() == 505 llvm::support::endianness::big) { 506 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 507 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 508 // work correctly in BE format. 509 // ex. 
`value` (2 words including 10 bytes) 510 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 511 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 512 rawData + (bitPos / CHAR_BIT)); 513 } else { 514 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 515 llvm::divideCeil(bitWidth, CHAR_BIT), 516 rawData + (bitPos / CHAR_BIT)); 517 } 518 } 519 520 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 521 /// `rawData`. 522 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 523 // Handle a boolean bit position. 524 if (bitWidth == 1) 525 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 526 527 // Otherwise, the bit position must be 8-bit aligned. 528 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 529 APInt result(bitWidth, 0); 530 if (llvm::support::endian::system_endianness() == 531 llvm::support::endianness::big) { 532 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 533 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 534 // work correctly in BE format. 535 // ex. `result` (2 words including 10 bytes) 536 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 537 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 538 llvm::divideCeil(bitWidth, CHAR_BIT), result); 539 } else { 540 std::copy_n(rawData + (bitPos / CHAR_BIT), 541 llvm::divideCeil(bitWidth, CHAR_BIT), 542 const_cast<char *>( 543 reinterpret_cast<const char *>(result.getRawData()))); 544 } 545 return result; 546 } 547 548 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 549 /// the same element count as 'type'. 
550 template <typename Values> 551 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 552 return (values.size() == 1) || 553 (type.getNumElements() == static_cast<int64_t>(values.size())); 554 } 555 556 //===----------------------------------------------------------------------===// 557 // DenseElementsAttr Iterators 558 //===----------------------------------------------------------------------===// 559 560 //===----------------------------------------------------------------------===// 561 // AttributeElementIterator 562 563 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 564 DenseElementsAttr attr, size_t index) 565 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 566 Attribute, Attribute, Attribute>( 567 attr.getAsOpaquePointer(), index) {} 568 569 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 570 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 571 Type eltTy = owner.getType().getElementType(); 572 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 573 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 574 if (eltTy.isa<IndexType>()) 575 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 576 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 577 IntElementIterator intIt(owner, index); 578 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 579 return FloatAttr::get(eltTy, *floatIt); 580 } 581 if (owner.isa<DenseStringElementsAttr>()) { 582 ArrayRef<StringRef> vals = owner.getRawStringData(); 583 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 584 } 585 llvm_unreachable("unexpected element type"); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // BoolElementIterator 590 591 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 592 DenseElementsAttr attr, size_t dataIndex) 593 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 594 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 595 596 bool DenseElementsAttr::BoolElementIterator::operator*() const { 597 return getBit(getData(), getDataIndex()); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // IntElementIterator 602 603 DenseElementsAttr::IntElementIterator::IntElementIterator( 604 DenseElementsAttr attr, size_t dataIndex) 605 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 606 attr.getRawData().data(), attr.isSplat(), dataIndex), 607 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 608 609 APInt DenseElementsAttr::IntElementIterator::operator*() const { 610 return readBits(getData(), 611 getDataIndex() * getDenseElementStorageWidth(bitWidth), 612 bitWidth); 613 } 614 615 //===----------------------------------------------------------------------===// 616 // ComplexIntElementIterator 617 618 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 619 DenseElementsAttr attr, size_t dataIndex) 620 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 621 std::complex<APInt>, std::complex<APInt>, 622 std::complex<APInt>>( 623 attr.getRawData().data(), attr.isSplat(), dataIndex) { 624 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 625 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 626 } 627 628 std::complex<APInt> 629 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 630 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 631 size_t offset = 
getDataIndex() * storageWidth * 2; 632 return {readBits(getData(), offset, bitWidth), 633 readBits(getData(), offset + storageWidth, bitWidth)}; 634 } 635 636 //===----------------------------------------------------------------------===// 637 // FloatElementIterator 638 639 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 640 const llvm::fltSemantics &smt, IntElementIterator it) 641 : llvm::mapped_iterator<IntElementIterator, 642 std::function<APFloat(const APInt &)>>( 643 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 644 645 //===----------------------------------------------------------------------===// 646 // ComplexFloatElementIterator 647 648 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 649 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 650 : llvm::mapped_iterator< 651 ComplexIntElementIterator, 652 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 653 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 654 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 655 }) {} 656 657 //===----------------------------------------------------------------------===// 658 // DenseElementsAttr 659 //===----------------------------------------------------------------------===// 660 661 /// Method for support type inquiry through isa, cast and dyn_cast. 662 bool DenseElementsAttr::classof(Attribute attr) { 663 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 664 } 665 666 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 667 ArrayRef<Attribute> values) { 668 assert(hasSameElementsOrSplat(type, values)); 669 670 // If the element type is not based on int/float/index, assume it is a string 671 // type. 
672 auto eltType = type.getElementType(); 673 if (!type.getElementType().isIntOrIndexOrFloat()) { 674 SmallVector<StringRef, 8> stringValues; 675 stringValues.reserve(values.size()); 676 for (Attribute attr : values) { 677 assert(attr.isa<StringAttr>() && 678 "expected string value for non integer/index/float element"); 679 stringValues.push_back(attr.cast<StringAttr>().getValue()); 680 } 681 return get(type, stringValues); 682 } 683 684 // Otherwise, get the raw storage width to use for the allocation. 685 size_t bitWidth = getDenseElementBitWidth(eltType); 686 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 687 688 // Compress the attribute values into a character buffer. 689 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 690 values.size()); 691 APInt intVal; 692 for (unsigned i = 0, e = values.size(); i < e; ++i) { 693 assert(eltType == values[i].getType() && 694 "expected attribute value to have element type"); 695 if (eltType.isa<FloatType>()) 696 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 697 else if (eltType.isa<IntegerType, IndexType>()) 698 intVal = values[i].cast<IntegerAttr>().getValue(); 699 else 700 llvm_unreachable("unexpected element type"); 701 702 assert(intVal.getBitWidth() == bitWidth && 703 "expected value to have same bitwidth as element type"); 704 writeBits(data.data(), i * storageBitWidth, intVal); 705 } 706 return DenseIntOrFPElementsAttr::getRaw(type, data, 707 /*isSplat=*/(values.size() == 1)); 708 } 709 710 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 711 ArrayRef<bool> values) { 712 assert(hasSameElementsOrSplat(type, values)); 713 assert(type.getElementType().isInteger(1)); 714 715 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 716 for (int i = 0, e = values.size(); i != e; ++i) 717 setBit(buff.data(), i, values[i]); 718 return DenseIntOrFPElementsAttr::getRaw(type, buff, 719 /*isSplat=*/(values.size() == 1)); 720 } 721 722 
/// Constructs a dense string elements attribute. The element type of `type`
/// must not be an integer or floating-point type.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  Type elementType = type.getElementType();
  assert(!elementType.isIntOrFloat());
  (void)elementType;
  return DenseStringElementsAttr::get(type, values);
}

/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type elementType = type.getElementType();
  assert(elementType.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));
  const bool isSplat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(
      type, getDenseElementStorageWidth(elementType), values, isSplat);
}
/// Constructs a dense complex-integer elements attribute. Each complex value
/// contributes its real and imaginary APInt components.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  auto complexType = type.getElementType().cast<ComplexType>();
  assert(complexType.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));
  // View the complex values as a flat array of 2 * N components. This relies
  // on std::complex<APInt> being laid out as two adjacent APInts (the same
  // assumption the reinterpret_cast has always made).
  const APInt *components = reinterpret_cast<const APInt *>(values.data());
  ArrayRef<APInt> flatValues(components, 2 * values.size());
  // Each component receives half of the complex element's storage width.
  size_t componentStorageWidth = getDenseElementStorageWidth(complexType) / 2;
  const bool isSplat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(type, componentStorageWidth,
                                          flatValues, isSplat);
}

/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'.
754 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 755 ArrayRef<APFloat> values) { 756 assert(type.getElementType().isa<FloatType>()); 757 assert(hasSameElementsOrSplat(type, values)); 758 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 759 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 760 /*isSplat=*/(values.size() == 1)); 761 } 762 DenseElementsAttr 763 DenseElementsAttr::get(ShapedType type, 764 ArrayRef<std::complex<APFloat>> values) { 765 ComplexType complex = type.getElementType().cast<ComplexType>(); 766 assert(complex.getElementType().isa<FloatType>()); 767 assert(hasSameElementsOrSplat(type, values)); 768 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 769 values.size() * 2); 770 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 771 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 772 /*isSplat=*/(values.size() == 1)); 773 } 774 775 /// Construct a dense elements attribute from a raw buffer representing the 776 /// data for this attribute. Users should generally not use this methods as 777 /// the expected buffer format may not be a form the user expects. 778 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 779 ArrayRef<char> rawBuffer, 780 bool isSplatBuffer) { 781 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 782 } 783 784 /// Returns true if the given buffer is a valid raw buffer for the given type. 785 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 786 ArrayRef<char> rawBuffer, 787 bool &detectedSplat) { 788 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 789 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 790 791 // Storage width of 1 is special as it is packed by the bit. 792 if (storageWidth == 1) { 793 // Check for a splat, or a buffer equal to the number of elements. 
794 if ((detectedSplat = rawBuffer.size() == 1)) 795 return true; 796 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 797 } 798 // All other types are 8-bit aligned. 799 if ((detectedSplat = rawBufferWidth == storageWidth)) 800 return true; 801 return rawBufferWidth == (storageWidth * type.getNumElements()); 802 } 803 804 /// Check the information for a C++ data type, check if this type is valid for 805 /// the current attribute. This method is used to verify specific type 806 /// invariants that the templatized 'getValues' method cannot. 807 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 808 bool isSigned) { 809 // Make sure that the data element size is the same as the type element width. 810 if (getDenseElementBitWidth(type) != 811 static_cast<size_t>(dataEltSize * CHAR_BIT)) 812 return false; 813 814 // Check that the element type is either float or integer or index. 815 if (!isInt) 816 return type.isa<FloatType>(); 817 if (type.isIndex()) 818 return true; 819 820 auto intType = type.dyn_cast<IntegerType>(); 821 if (!intType) 822 return false; 823 824 // Make sure signedness semantics is consistent. 825 if (intType.isSignless()) 826 return true; 827 return intType.isSigned() ? isSigned : !isSigned; 828 } 829 830 /// Defaults down the subclass implementation. 831 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 832 ArrayRef<char> data, 833 int64_t dataEltSize, 834 bool isInt, bool isSigned) { 835 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 836 isSigned); 837 } 838 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 839 ArrayRef<char> data, 840 int64_t dataEltSize, 841 bool isInt, 842 bool isSigned) { 843 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 844 isInt, isSigned); 845 } 846 847 /// A method used to verify specific type invariants that the templatized 'get' 848 /// method cannot. 
849 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 850 bool isSigned) const { 851 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 852 isSigned); 853 } 854 855 /// Check the information for a C++ data type, check if this type is valid for 856 /// the current attribute. 857 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 858 bool isSigned) const { 859 return ::isValidIntOrFloat( 860 getType().getElementType().cast<ComplexType>().getElementType(), 861 dataEltSize / 2, isInt, isSigned); 862 } 863 864 /// Returns true if this attribute corresponds to a splat, i.e. if all element 865 /// values are the same. 866 bool DenseElementsAttr::isSplat() const { 867 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 868 } 869 870 /// Return the held element values as a range of Attributes. 871 auto DenseElementsAttr::getAttributeValues() const 872 -> llvm::iterator_range<AttributeElementIterator> { 873 return {attr_value_begin(), attr_value_end()}; 874 } 875 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 876 return AttributeElementIterator(*this, 0); 877 } 878 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 879 return AttributeElementIterator(*this, getNumElements()); 880 } 881 882 /// Return the held element values as a range of bool. The element type of 883 /// this attribute must be of integer type of bitwidth 1. 884 auto DenseElementsAttr::getBoolValues() const 885 -> llvm::iterator_range<BoolElementIterator> { 886 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 887 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 888 (void)eltType; 889 return {BoolElementIterator(*this, 0), 890 BoolElementIterator(*this, getNumElements())}; 891 } 892 893 /// Return the held element values as a range of APInts. The element type of 894 /// this attribute must be of integer type. 
895 auto DenseElementsAttr::getIntValues() const 896 -> llvm::iterator_range<IntElementIterator> { 897 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 898 return {raw_int_begin(), raw_int_end()}; 899 } 900 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 901 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 902 return raw_int_begin(); 903 } 904 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 905 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 906 return raw_int_end(); 907 } 908 auto DenseElementsAttr::getComplexIntValues() const 909 -> llvm::iterator_range<ComplexIntElementIterator> { 910 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 911 (void)eltTy; 912 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 913 return {ComplexIntElementIterator(*this, 0), 914 ComplexIntElementIterator(*this, getNumElements())}; 915 } 916 917 /// Return the held element values as a range of APFloat. The element type of 918 /// this attribute must be of float type. 
919 auto DenseElementsAttr::getFloatValues() const 920 -> llvm::iterator_range<FloatElementIterator> { 921 auto elementType = getType().getElementType().cast<FloatType>(); 922 const auto &elementSemantics = elementType.getFloatSemantics(); 923 return {FloatElementIterator(elementSemantics, raw_int_begin()), 924 FloatElementIterator(elementSemantics, raw_int_end())}; 925 } 926 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 927 return getFloatValues().begin(); 928 } 929 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 930 return getFloatValues().end(); 931 } 932 auto DenseElementsAttr::getComplexFloatValues() const 933 -> llvm::iterator_range<ComplexFloatElementIterator> { 934 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 935 assert(eltTy.isa<FloatType>() && "expected complex float type"); 936 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 937 return {{semantics, {*this, 0}}, 938 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 939 } 940 941 /// Return the raw storage data held by this attribute. 942 ArrayRef<char> DenseElementsAttr::getRawData() const { 943 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 944 } 945 946 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 947 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 948 } 949 950 /// Return a new DenseElementsAttr that has the same data as the current 951 /// attribute, but has been reshaped to 'newType'. The new type must have the 952 /// same total number of elements as well as element type. 
953 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 954 ShapedType curType = getType(); 955 if (curType == newType) 956 return *this; 957 958 (void)curType; 959 assert(newType.getElementType() == curType.getElementType() && 960 "expected the same element type"); 961 assert(newType.getNumElements() == curType.getNumElements() && 962 "expected the same number of elements"); 963 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 964 } 965 966 DenseElementsAttr 967 DenseElementsAttr::mapValues(Type newElementType, 968 function_ref<APInt(const APInt &)> mapping) const { 969 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 970 } 971 972 DenseElementsAttr DenseElementsAttr::mapValues( 973 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 974 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 975 } 976 977 //===----------------------------------------------------------------------===// 978 // DenseIntOrFPElementsAttr 979 //===----------------------------------------------------------------------===// 980 981 /// Utility method to write a range of APInt values to a buffer. 982 template <typename APRangeT> 983 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 984 APRangeT &&values) { 985 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 986 size_t offset = 0; 987 for (auto it = values.begin(), e = values.end(); it != e; 988 ++it, offset += storageWidth) { 989 assert((*it).getBitWidth() <= storageWidth); 990 writeBits(data.data(), offset, *it); 991 } 992 } 993 994 /// Constructs a dense elements attribute from an array of raw APFloat values. 995 /// Each APFloat value is expected to have the same bitwidth as the element 996 /// type of 'type'. 'type' must be a vector or tensor with static shape. 
997 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 998 size_t storageWidth, 999 ArrayRef<APFloat> values, 1000 bool isSplat) { 1001 std::vector<char> data; 1002 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1003 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1004 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1005 } 1006 1007 /// Constructs a dense elements attribute from an array of raw APInt values. 1008 /// Each APInt value is expected to have the same bitwidth as the element type 1009 /// of 'type'. 1010 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1011 size_t storageWidth, 1012 ArrayRef<APInt> values, 1013 bool isSplat) { 1014 std::vector<char> data; 1015 writeAPIntsToBuffer(storageWidth, data, values); 1016 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1017 } 1018 1019 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1020 ArrayRef<char> data, 1021 bool isSplat) { 1022 assert((type.isa<RankedTensorType, VectorType>()) && 1023 "type must be ranked tensor or vector"); 1024 assert(type.hasStaticShape() && "type must have static shape"); 1025 return Base::get(type.getContext(), type, data, isSplat); 1026 } 1027 1028 /// Overload of the raw 'get' method that asserts that the given type is of 1029 /// complex type. This method is used to verify type invariants that the 1030 /// templatized 'get' method cannot. 
1031 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1032 ArrayRef<char> data, 1033 int64_t dataEltSize, 1034 bool isInt, 1035 bool isSigned) { 1036 assert(::isValidIntOrFloat( 1037 type.getElementType().cast<ComplexType>().getElementType(), 1038 dataEltSize / 2, isInt, isSigned)); 1039 1040 int64_t numElements = data.size() / dataEltSize; 1041 assert(numElements == 1 || numElements == type.getNumElements()); 1042 return getRaw(type, data, /*isSplat=*/numElements == 1); 1043 } 1044 1045 /// Overload of the 'getRaw' method that asserts that the given type is of 1046 /// integer type. This method is used to verify type invariants that the 1047 /// templatized 'get' method cannot. 1048 DenseElementsAttr 1049 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1050 int64_t dataEltSize, bool isInt, 1051 bool isSigned) { 1052 assert( 1053 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1054 1055 int64_t numElements = data.size() / dataEltSize; 1056 assert(numElements == 1 || numElements == type.getNumElements()); 1057 return getRaw(type, data, /*isSplat=*/numElements == 1); 1058 } 1059 1060 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1061 const char *inRawData, char *outRawData, size_t elementBitWidth, 1062 size_t numElements) { 1063 using llvm::support::ulittle16_t; 1064 using llvm::support::ulittle32_t; 1065 using llvm::support::ulittle64_t; 1066 1067 assert(llvm::support::endian::system_endianness() == // NOLINT 1068 llvm::support::endianness::big); // NOLINT 1069 // NOLINT to avoid warning message about replacing by static_assert() 1070 1071 // Following std::copy_n always converts endianness on BE machine. 
1072 switch (elementBitWidth) { 1073 case 16: { 1074 const ulittle16_t *inRawDataPos = 1075 reinterpret_cast<const ulittle16_t *>(inRawData); 1076 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1077 std::copy_n(inRawDataPos, numElements, outDataPos); 1078 break; 1079 } 1080 case 32: { 1081 const ulittle32_t *inRawDataPos = 1082 reinterpret_cast<const ulittle32_t *>(inRawData); 1083 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1084 std::copy_n(inRawDataPos, numElements, outDataPos); 1085 break; 1086 } 1087 case 64: { 1088 const ulittle64_t *inRawDataPos = 1089 reinterpret_cast<const ulittle64_t *>(inRawData); 1090 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1091 std::copy_n(inRawDataPos, numElements, outDataPos); 1092 break; 1093 } 1094 default: { 1095 size_t nBytes = elementBitWidth / CHAR_BIT; 1096 for (size_t i = 0; i < nBytes; i++) 1097 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1098 break; 1099 } 1100 } 1101 } 1102 1103 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1104 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1105 ShapedType type) { 1106 size_t numElements = type.getNumElements(); 1107 Type elementType = type.getElementType(); 1108 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1109 elementType = complexTy.getElementType(); 1110 numElements = numElements * 2; 1111 } 1112 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1113 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1114 inRawData.size() <= outRawData.size()); 1115 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1116 elementBitWidth, numElements); 1117 } 1118 1119 //===----------------------------------------------------------------------===// 1120 // DenseFPElementsAttr 1121 //===----------------------------------------------------------------------===// 1122 1123 template <typename Fn, typename Attr> 
1124 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1125 Type newElementType, 1126 llvm::SmallVectorImpl<char> &data) { 1127 size_t bitWidth = getDenseElementBitWidth(newElementType); 1128 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1129 1130 ShapedType newArrayType; 1131 if (inType.isa<RankedTensorType>()) 1132 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1133 else if (inType.isa<UnrankedTensorType>()) 1134 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1135 else if (inType.isa<VectorType>()) 1136 newArrayType = VectorType::get(inType.getShape(), newElementType); 1137 else 1138 assert(newArrayType && "Unhandled tensor type"); 1139 1140 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1141 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1142 1143 // Functor used to process a single element value of the attribute. 1144 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1145 auto newInt = mapping(value); 1146 assert(newInt.getBitWidth() == bitWidth); 1147 writeBits(data.data(), index * storageBitWidth, newInt); 1148 }; 1149 1150 // Check for the splat case. 1151 if (attr.isSplat()) { 1152 processElt(*attr.begin(), /*index=*/0); 1153 return newArrayType; 1154 } 1155 1156 // Otherwise, process all of the element values. 1157 uint64_t elementIdx = 0; 1158 for (auto value : attr) 1159 processElt(value, elementIdx++); 1160 return newArrayType; 1161 } 1162 1163 DenseElementsAttr DenseFPElementsAttr::mapValues( 1164 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1165 llvm::SmallVector<char, 8> elementData; 1166 auto newArrayType = 1167 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1168 1169 return getRaw(newArrayType, elementData, isSplat()); 1170 } 1171 1172 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1173 bool DenseFPElementsAttr::classof(Attribute attr) { 1174 return attr.isa<DenseElementsAttr>() && 1175 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1176 } 1177 1178 //===----------------------------------------------------------------------===// 1179 // DenseIntElementsAttr 1180 //===----------------------------------------------------------------------===// 1181 1182 DenseElementsAttr DenseIntElementsAttr::mapValues( 1183 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1184 llvm::SmallVector<char, 8> elementData; 1185 auto newArrayType = 1186 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1187 1188 return getRaw(newArrayType, elementData, isSplat()); 1189 } 1190 1191 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1192 bool DenseIntElementsAttr::classof(Attribute attr) { 1193 return attr.isa<DenseElementsAttr>() && 1194 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1195 } 1196 1197 //===----------------------------------------------------------------------===// 1198 // OpaqueElementsAttr 1199 //===----------------------------------------------------------------------===// 1200 1201 /// Return the value at the given index. If index does not refer to a valid 1202 /// element, then a null attribute is returned. 
1203 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1204 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1205 return Attribute(); 1206 } 1207 1208 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1209 Dialect *dialect = getDialect().getDialect(); 1210 if (!dialect) 1211 return true; 1212 auto *interface = 1213 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1214 if (!interface) 1215 return true; 1216 return failed(interface->decode(*this, result)); 1217 } 1218 1219 LogicalResult 1220 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1221 Identifier dialect, StringRef value, 1222 ShapedType type) { 1223 if (!Dialect::isValidNamespace(dialect.strref())) 1224 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1225 return success(); 1226 } 1227 1228 //===----------------------------------------------------------------------===// 1229 // SparseElementsAttr 1230 //===----------------------------------------------------------------------===// 1231 1232 /// Return the value of the element at the given index. 1233 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1234 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1235 auto type = getType(); 1236 1237 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1238 // as a 1-D index array. 1239 auto sparseIndices = getIndices(); 1240 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1241 1242 // Check to see if the indices are a splat. 1243 if (sparseIndices.isSplat()) { 1244 // If the index is also not a splat of the index value, we know that the 1245 // value is zero. 1246 auto splatIndex = *sparseIndexValues.begin(); 1247 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1248 return getZeroAttr(); 1249 1250 // If the indices are a splat, we also expect the values to be a splat. 
1251 assert(getValues().isSplat() && "expected splat values"); 1252 return getValues().getSplatValue(); 1253 } 1254 1255 // Build a mapping between known indices and the offset of the stored element. 1256 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1257 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1258 size_t rank = type.getRank(); 1259 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1260 mappedIndices.try_emplace( 1261 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1262 1263 // Look for the provided index key within the mapped indices. If the provided 1264 // index is not found, then return a zero attribute. 1265 auto it = mappedIndices.find(index); 1266 if (it == mappedIndices.end()) 1267 return getZeroAttr(); 1268 1269 // Otherwise, return the held sparse value element. 1270 return getValues().getValue(it->second); 1271 } 1272 1273 /// Get a zero APFloat for the given sparse attribute. 1274 APFloat SparseElementsAttr::getZeroAPFloat() const { 1275 auto eltType = getType().getElementType().cast<FloatType>(); 1276 return APFloat(eltType.getFloatSemantics()); 1277 } 1278 1279 /// Get a zero APInt for the given sparse attribute. 1280 APInt SparseElementsAttr::getZeroAPInt() const { 1281 auto eltType = getType().getElementType().cast<IntegerType>(); 1282 return APInt::getNullValue(eltType.getWidth()); 1283 } 1284 1285 /// Get a zero attribute for the given attribute type. 1286 Attribute SparseElementsAttr::getZeroAttr() const { 1287 auto eltType = getType().getElementType(); 1288 1289 // Handle floating point elements. 1290 if (eltType.isa<FloatType>()) 1291 return FloatAttr::get(eltType, 0); 1292 1293 // Otherwise, this is an integer. 1294 // TODO: Handle StringAttr here. 1295 return IntegerAttr::get(eltType, 0); 1296 } 1297 1298 /// Flatten, and return, all of the sparse indices in this attribute in 1299 /// row-major order. 
1300 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1301 std::vector<ptrdiff_t> flatSparseIndices; 1302 1303 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1304 // as a 1-D index array. 1305 auto sparseIndices = getIndices(); 1306 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1307 if (sparseIndices.isSplat()) { 1308 SmallVector<uint64_t, 8> indices(getType().getRank(), 1309 *sparseIndexValues.begin()); 1310 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1311 return flatSparseIndices; 1312 } 1313 1314 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1315 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1316 size_t rank = getType().getRank(); 1317 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1318 flatSparseIndices.push_back(getFlattenedIndex( 1319 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1320 return flatSparseIndices; 1321 } 1322