1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "mlir/IR/Types.h" 19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 20 #include "llvm/ADT/APSInt.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/Support/Endian.h" 23 24 using namespace mlir; 25 using namespace mlir::detail; 26 27 //===----------------------------------------------------------------------===// 28 /// Tablegen Attribute Definitions 29 //===----------------------------------------------------------------------===// 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/IR/BuiltinAttributes.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 // BuiltinDialect 36 //===----------------------------------------------------------------------===// 37 38 void BuiltinDialect::registerAttributes() { 39 addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr, 40 DenseIntOrFPElementsAttr, DenseStringElementsAttr, 41 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr, 42 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr, 43 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // ArrayAttr 48 //===----------------------------------------------------------------------===// 49 50 void ArrayAttr::walkImmediateSubElements( 51 function_ref<void(Attribute)> walkAttrsFn, 52 function_ref<void(Type)> walkTypesFn) const { 53 for (Attribute attr : getValue()) 54 walkAttrsFn(attr); 55 } 56 57 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute( 58 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 59 std::vector<Attribute> vector = getValue().vec(); 60 for (auto &it : replacements) { 61 vector[it.first] = it.second; 62 } 63 return get(getContext(), vector); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // DictionaryAttr 68 //===----------------------------------------------------------------------===// 69 70 /// Helper function that does either an in place sort or sorts from source array 71 /// into destination. If inPlace then storage is both the source and the 72 /// destination, else value is the source and storage destination. Returns 73 /// whether source was sorted. 74 template <bool inPlace> 75 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 76 SmallVectorImpl<NamedAttribute> &storage) { 77 // Specialize for the common case. 78 switch (value.size()) { 79 case 0: 80 // Zero already sorted. 81 if (!inPlace) 82 storage.clear(); 83 break; 84 case 1: 85 // One already sorted but may need to be copied. 86 if (!inPlace) 87 storage.assign({value[0]}); 88 break; 89 case 2: { 90 bool isSorted = value[0] < value[1]; 91 if (inPlace) { 92 if (!isSorted) 93 std::swap(storage[0], storage[1]); 94 } else if (isSorted) { 95 storage.assign({value[0], value[1]}); 96 } else { 97 storage.assign({value[1], value[0]}); 98 } 99 return !isSorted; 100 } 101 default: 102 if (!inPlace) 103 storage.assign(value.begin(), value.end()); 104 // Check to see they are sorted already. 105 bool isSorted = llvm::is_sorted(value); 106 // If not, do a general sort. 107 if (!isSorted) 108 llvm::array_pod_sort(storage.begin(), storage.end()); 109 return !isSorted; 110 } 111 return false; 112 } 113 114 /// Returns an entry with a duplicate name from the given sorted array of named 115 /// attributes. Returns llvm::None if all elements have unique names. 116 static Optional<NamedAttribute> 117 findDuplicateElement(ArrayRef<NamedAttribute> value) { 118 const Optional<NamedAttribute> none{llvm::None}; 119 if (value.size() < 2) 120 return none; 121 122 if (value.size() == 2) 123 return value[0].getName() == value[1].getName() ? value[0] : none; 124 125 const auto *it = std::adjacent_find(value.begin(), value.end(), 126 [](NamedAttribute l, NamedAttribute r) { 127 return l.getName() == r.getName(); 128 }); 129 return it != value.end() ? *it : none; 130 } 131 132 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 133 SmallVectorImpl<NamedAttribute> &storage) { 134 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 135 assert(!findDuplicateElement(storage) && 136 "DictionaryAttr element names must be unique"); 137 return isSorted; 138 } 139 140 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 141 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 142 assert(!findDuplicateElement(array) && 143 "DictionaryAttr element names must be unique"); 144 return isSorted; 145 } 146 147 Optional<NamedAttribute> 148 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 149 bool isSorted) { 150 if (!isSorted) 151 dictionaryAttrSort</*inPlace=*/true>(array, array); 152 return findDuplicateElement(array); 153 } 154 155 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 156 ArrayRef<NamedAttribute> value) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 160 // We need to sort the element list to canonicalize it. 161 SmallVector<NamedAttribute, 8> storage; 162 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 163 value = storage; 164 assert(!findDuplicateElement(value) && 165 "DictionaryAttr element names must be unique"); 166 return Base::get(context, value); 167 } 168 /// Construct a dictionary with an array of values that is known to already be 169 /// sorted by name and uniqued. 170 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 171 ArrayRef<NamedAttribute> value) { 172 if (value.empty()) 173 return DictionaryAttr::getEmpty(context); 174 // Ensure that the attribute elements are unique and sorted. 175 assert(llvm::is_sorted( 176 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && 177 "expected attribute values to be sorted"); 178 assert(!findDuplicateElement(value) && 179 "DictionaryAttr element names must be unique"); 180 return Base::get(context, value); 181 } 182 183 /// Return the specified attribute if present, null otherwise. 184 Attribute DictionaryAttr::get(StringRef name) const { 185 auto it = impl::findAttrSorted(begin(), end(), name); 186 return it.second ? it.first->getValue() : Attribute(); 187 } 188 Attribute DictionaryAttr::get(StringAttr name) const { 189 auto it = impl::findAttrSorted(begin(), end(), name); 190 return it.second ? it.first->getValue() : Attribute(); 191 } 192 193 /// Return the specified named attribute if present, None otherwise. 194 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 195 auto it = impl::findAttrSorted(begin(), end(), name); 196 return it.second ? *it.first : Optional<NamedAttribute>(); 197 } 198 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { 199 auto it = impl::findAttrSorted(begin(), end(), name); 200 return it.second ? *it.first : Optional<NamedAttribute>(); 201 } 202 203 /// Return whether the specified attribute is present. 204 bool DictionaryAttr::contains(StringRef name) const { 205 return impl::findAttrSorted(begin(), end(), name).second; 206 } 207 bool DictionaryAttr::contains(StringAttr name) const { 208 return impl::findAttrSorted(begin(), end(), name).second; 209 } 210 211 DictionaryAttr::iterator DictionaryAttr::begin() const { 212 return getValue().begin(); 213 } 214 DictionaryAttr::iterator DictionaryAttr::end() const { 215 return getValue().end(); 216 } 217 size_t DictionaryAttr::size() const { return getValue().size(); } 218 219 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 220 return Base::get(context, ArrayRef<NamedAttribute>()); 221 } 222 223 void DictionaryAttr::walkImmediateSubElements( 224 function_ref<void(Attribute)> walkAttrsFn, 225 function_ref<void(Type)> walkTypesFn) const { 226 for (const NamedAttribute &attr : getValue()) 227 walkAttrsFn(attr.getValue()); 228 } 229 230 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( 231 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 232 std::vector<NamedAttribute> vec = getValue().vec(); 233 for (auto &it : replacements) 234 vec[it.first].setValue(it.second); 235 236 // The above only modifies the mapped value, but not the key, and therefore 237 // not the order of the elements. It remains sorted 238 return getWithSorted(getContext(), vec); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // StringAttr 243 //===----------------------------------------------------------------------===// 244 245 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 246 return Base::get(context, "", NoneType::get(context)); 247 } 248 249 /// Twine support for StringAttr. 250 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 251 // Fast-path empty twine. 252 if (twine.isTriviallyEmpty()) 253 return get(context); 254 SmallVector<char, 32> tempStr; 255 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 256 } 257 258 /// Twine support for StringAttr. 259 StringAttr StringAttr::get(const Twine &twine, Type type) { 260 SmallVector<char, 32> tempStr; 261 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 262 } 263 264 StringRef StringAttr::getValue() const { return getImpl()->value; } 265 266 Dialect *StringAttr::getReferencedDialect() const { 267 return getImpl()->referencedDialect; 268 } 269 270 //===----------------------------------------------------------------------===// 271 // FloatAttr 272 //===----------------------------------------------------------------------===// 273 274 double FloatAttr::getValueAsDouble() const { 275 return getValueAsDouble(getValue()); 276 } 277 double FloatAttr::getValueAsDouble(APFloat value) { 278 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 279 bool losesInfo = false; 280 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 281 &losesInfo); 282 } 283 return value.convertToDouble(); 284 } 285 286 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 287 Type type, APFloat value) { 288 // Verify that the type is correct. 289 if (!type.isa<FloatType>()) 290 return emitError() << "expected floating point type"; 291 292 // Verify that the type semantics match that of the value. 293 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 294 return emitError() 295 << "FloatAttr type doesn't match the type implied by its value"; 296 } 297 return success(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // SymbolRefAttr 302 //===----------------------------------------------------------------------===// 303 304 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 305 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 306 return get(StringAttr::get(ctx, value), nestedRefs); 307 } 308 309 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 310 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 311 } 312 313 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 314 return get(value, {}).cast<FlatSymbolRefAttr>(); 315 } 316 317 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 318 auto symName = 319 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 320 assert(symName && "value does not have a valid symbol name"); 321 return SymbolRefAttr::get(symName); 322 } 323 324 StringAttr SymbolRefAttr::getLeafReference() const { 325 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 326 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // IntegerAttr 331 //===----------------------------------------------------------------------===// 332 333 int64_t IntegerAttr::getInt() const { 334 assert((getType().isIndex() || getType().isSignlessInteger()) && 335 "must be signless integer"); 336 return getValue().getSExtValue(); 337 } 338 339 int64_t IntegerAttr::getSInt() const { 340 assert(getType().isSignedInteger() && "must be signed integer"); 341 return getValue().getSExtValue(); 342 } 343 344 uint64_t IntegerAttr::getUInt() const { 345 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 346 return getValue().getZExtValue(); 347 } 348 349 /// Return the value as an APSInt which carries the signed from the type of 350 /// the attribute. This traps on signless integers types! 351 APSInt IntegerAttr::getAPSInt() const { 352 assert(!getType().isSignlessInteger() && 353 "Signless integers don't carry a sign for APSInt"); 354 return APSInt(getValue(), getType().isUnsignedInteger()); 355 } 356 357 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 358 Type type, APInt value) { 359 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 360 if (integerType.getWidth() != value.getBitWidth()) 361 return emitError() << "integer type bit width (" << integerType.getWidth() 362 << ") doesn't match value bit width (" 363 << value.getBitWidth() << ")"; 364 return success(); 365 } 366 if (type.isa<IndexType>()) 367 return success(); 368 return emitError() << "expected integer or index type"; 369 } 370 371 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 372 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 373 return attr.cast<BoolAttr>(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // BoolAttr 378 //===----------------------------------------------------------------------===// 379 380 bool BoolAttr::getValue() const { 381 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 382 return storage->value.getBoolValue(); 383 } 384 385 bool BoolAttr::classof(Attribute attr) { 386 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 387 return intAttr && intAttr.getType().isSignlessInteger(1); 388 } 389 390 //===----------------------------------------------------------------------===// 391 // OpaqueAttr 392 //===----------------------------------------------------------------------===// 393 394 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 395 StringAttr dialect, StringRef attrData, 396 Type type) { 397 if (!Dialect::isValidNamespace(dialect.strref())) 398 return emitError() << "invalid dialect namespace '" << dialect << "'"; 399 400 // Check that the dialect is actually registered. 401 MLIRContext *context = dialect.getContext(); 402 if (!context->allowsUnregisteredDialects() && 403 !context->getLoadedDialect(dialect.strref())) { 404 return emitError() 405 << "#" << dialect << "<\"" << attrData << "\"> : " << type 406 << " attribute created with unregistered dialect. If this is " 407 "intended, please call allowUnregisteredDialects() on the " 408 "MLIRContext, or use -allow-unregistered-dialect with " 409 "the MLIR opt tool used"; 410 } 411 412 return success(); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // DenseElementsAttr Utilities 417 //===----------------------------------------------------------------------===// 418 419 /// Get the bitwidth of a dense element type within the buffer. 420 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 421 static size_t getDenseElementStorageWidth(size_t origWidth) { 422 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 423 } 424 static size_t getDenseElementStorageWidth(Type elementType) { 425 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 426 } 427 428 /// Set a bit to a specific value. 429 static void setBit(char *rawData, size_t bitPos, bool value) { 430 if (value) 431 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 432 else 433 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 434 } 435 436 /// Return the value of the specified bit. 437 static bool getBit(const char *rawData, size_t bitPos) { 438 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 439 } 440 441 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 442 /// BE format. 443 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 444 char *result) { 445 assert(llvm::support::endian::system_endianness() == // NOLINT 446 llvm::support::endianness::big); // NOLINT 447 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 448 449 // Copy the words filled with data. 450 // For example, when `value` has 2 words, the first word is filled with data. 451 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 452 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 453 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 454 numFilledWords, result); 455 // Convert last word of APInt to LE format and store it in char 456 // array(`valueLE`). 457 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 458 size_t lastWordPos = numFilledWords; 459 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 460 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 461 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 462 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 463 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 464 // and store it in `result`. 465 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 466 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 467 valueLE.begin(), result + lastWordPos, 468 (numBytes - lastWordPos) * CHAR_BIT, 1); 469 } 470 471 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 472 /// format. 473 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 474 APInt &result) { 475 assert(llvm::support::endian::system_endianness() == // NOLINT 476 llvm::support::endianness::big); // NOLINT 477 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 478 479 // Copy the data that fills the word of `result` from `inArray`. 480 // For example, when `result` has 2 words, the first word will be filled with 481 // data. So, the first 8 bytes are copied from `inArray` here. 482 // `inArray` (10 bytes, BE): |abcdefgh|ij| 483 // ==> `result` (2 words, BE): |abcdefgh|--------| 484 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 485 std::copy_n( 486 inArray, numFilledWords, 487 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 488 489 // Convert array data which will be last word of `result` to LE format, and 490 // store it in char array(`inArrayLE`). 491 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 492 size_t lastWordPos = numFilledWords; 493 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 494 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 495 inArray + lastWordPos, inArrayLE.begin(), 496 (numBytes - lastWordPos) * CHAR_BIT, 1); 497 498 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 499 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 500 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 501 inArrayLE.begin(), 502 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 503 lastWordPos, 504 APInt::APINT_BITS_PER_WORD, 1); 505 } 506 507 /// Writes value to the bit position `bitPos` in array `rawData`. 508 static void writeBits(char *rawData, size_t bitPos, APInt value) { 509 size_t bitWidth = value.getBitWidth(); 510 511 // If the bitwidth is 1 we just toggle the specific bit. 512 if (bitWidth == 1) 513 return setBit(rawData, bitPos, value.isOneValue()); 514 515 // Otherwise, the bit position is guaranteed to be byte aligned. 516 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 517 if (llvm::support::endian::system_endianness() == 518 llvm::support::endianness::big) { 519 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 520 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 521 // work correctly in BE format. 522 // ex. `value` (2 words including 10 bytes) 523 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 524 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 525 rawData + (bitPos / CHAR_BIT)); 526 } else { 527 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 528 llvm::divideCeil(bitWidth, CHAR_BIT), 529 rawData + (bitPos / CHAR_BIT)); 530 } 531 } 532 533 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 534 /// `rawData`. 535 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 536 // Handle a boolean bit position. 537 if (bitWidth == 1) 538 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 539 540 // Otherwise, the bit position must be 8-bit aligned. 541 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 542 APInt result(bitWidth, 0); 543 if (llvm::support::endian::system_endianness() == 544 llvm::support::endianness::big) { 545 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 546 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 547 // work correctly in BE format. 548 // ex. `result` (2 words including 10 bytes) 549 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 550 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 551 llvm::divideCeil(bitWidth, CHAR_BIT), result); 552 } else { 553 std::copy_n(rawData + (bitPos / CHAR_BIT), 554 llvm::divideCeil(bitWidth, CHAR_BIT), 555 const_cast<char *>( 556 reinterpret_cast<const char *>(result.getRawData()))); 557 } 558 return result; 559 } 560 561 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 562 /// the same element count as 'type'. 563 template <typename Values> 564 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 565 return (values.size() == 1) || 566 (type.getNumElements() == static_cast<int64_t>(values.size())); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // DenseElementsAttr Iterators 571 //===----------------------------------------------------------------------===// 572 573 //===----------------------------------------------------------------------===// 574 // AttributeElementIterator 575 576 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 577 DenseElementsAttr attr, size_t index) 578 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 579 Attribute, Attribute, Attribute>( 580 attr.getAsOpaquePointer(), index) {} 581 582 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 583 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 584 Type eltTy = owner.getElementType(); 585 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 586 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 587 if (eltTy.isa<IndexType>()) 588 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 589 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 590 IntElementIterator intIt(owner, index); 591 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 592 return FloatAttr::get(eltTy, *floatIt); 593 } 594 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 595 auto complexEltTy = complexTy.getElementType(); 596 ComplexIntElementIterator complexIntIt(owner, index); 597 if (complexEltTy.isa<IntegerType>()) { 598 auto value = *complexIntIt; 599 auto real = IntegerAttr::get(complexEltTy, value.real()); 600 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 601 return ArrayAttr::get(complexTy.getContext(), 602 ArrayRef<Attribute>{real, imag}); 603 } 604 605 ComplexFloatElementIterator complexFloatIt( 606 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 607 auto value = *complexFloatIt; 608 auto real = FloatAttr::get(complexEltTy, value.real()); 609 auto imag = FloatAttr::get(complexEltTy, value.imag()); 610 return ArrayAttr::get(complexTy.getContext(), 611 ArrayRef<Attribute>{real, imag}); 612 } 613 if (owner.isa<DenseStringElementsAttr>()) { 614 ArrayRef<StringRef> vals = owner.getRawStringData(); 615 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 616 } 617 llvm_unreachable("unexpected element type"); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // BoolElementIterator 622 623 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 624 DenseElementsAttr attr, size_t dataIndex) 625 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 626 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 627 628 bool DenseElementsAttr::BoolElementIterator::operator*() const { 629 return getBit(getData(), getDataIndex()); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // IntElementIterator 634 635 DenseElementsAttr::IntElementIterator::IntElementIterator( 636 DenseElementsAttr attr, size_t dataIndex) 637 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 638 attr.getRawData().data(), attr.isSplat(), dataIndex), 639 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 640 641 APInt DenseElementsAttr::IntElementIterator::operator*() const { 642 return readBits(getData(), 643 getDataIndex() * getDenseElementStorageWidth(bitWidth), 644 bitWidth); 645 } 646 647 //===----------------------------------------------------------------------===// 648 // ComplexIntElementIterator 649 650 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 651 DenseElementsAttr attr, size_t dataIndex) 652 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 653 std::complex<APInt>, std::complex<APInt>, 654 std::complex<APInt>>( 655 attr.getRawData().data(), attr.isSplat(), dataIndex) { 656 auto complexType = attr.getElementType().cast<ComplexType>(); 657 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 658 } 659 660 std::complex<APInt> 661 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 662 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 663 size_t offset = getDataIndex() * storageWidth * 2; 664 return {readBits(getData(), offset, bitWidth), 665 readBits(getData(), offset + storageWidth, bitWidth)}; 666 } 667 668 //===----------------------------------------------------------------------===// 669 // DenseArrayAttr 670 //===----------------------------------------------------------------------===// 671 672 /// Custom storage to ensure proper memory alignment for the allocation of 673 /// DenseArray of any element type. 674 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage { 675 using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType, 676 ::llvm::ArrayRef<char>>; 677 DenseArrayBaseAttrStorage(ShapedType type, 678 DenseArrayBaseAttr::EltType eltType, 679 ::llvm::ArrayRef<char> elements) 680 : AttributeStorage(type), eltType(eltType), elements(elements) {} 681 682 bool operator==(const KeyTy &tblgenKey) const { 683 return (getType() == std::get<0>(tblgenKey)) && 684 (eltType == std::get<1>(tblgenKey)) && 685 (elements == std::get<2>(tblgenKey)); 686 } 687 688 static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { 689 return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), 690 std::get<2>(tblgenKey)); 691 } 692 693 static DenseArrayBaseAttrStorage * 694 construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) { 695 auto type = std::get<0>(tblgenKey); 696 auto eltType = std::get<1>(tblgenKey); 697 auto elements = std::get<2>(tblgenKey); 698 if (!elements.empty()) { 699 char *alloc = static_cast<char *>( 700 allocator.allocate(elements.size(), alignof(uint64_t))); 701 std::uninitialized_copy(elements.begin(), elements.end(), alloc); 702 elements = ArrayRef<char>(alloc, elements.size()); 703 } 704 return new (allocator.allocate<DenseArrayBaseAttrStorage>()) 705 DenseArrayBaseAttrStorage(type, eltType, elements); 706 } 707 708 DenseArrayBaseAttr::EltType eltType; 709 ::llvm::ArrayRef<char> elements; 710 }; 711 712 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const { 713 return getImpl()->eltType; 714 } 715 716 const int8_t * 717 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const { 718 return cast<DenseI8ArrayAttr>().asArrayRef().begin(); 719 } 720 const int16_t * 721 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const { 722 return cast<DenseI16ArrayAttr>().asArrayRef().begin(); 723 } 724 const int32_t * 725 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const { 726 return cast<DenseI32ArrayAttr>().asArrayRef().begin(); 727 } 728 const int64_t * 729 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const { 730 return cast<DenseI64ArrayAttr>().asArrayRef().begin(); 731 } 732 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const { 733 return cast<DenseF32ArrayAttr>().asArrayRef().begin(); 734 } 735 const double * 736 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const { 737 return cast<DenseF64ArrayAttr>().asArrayRef().begin(); 738 } 739 740 void DenseArrayBaseAttr::print(AsmPrinter &printer) const { 741 print(printer.getStream()); 742 } 743 744 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { 745 switch (getElementType()) { 746 case DenseArrayBaseAttr::EltType::I8: 747 this->cast<DenseI8ArrayAttr>().printWithoutBraces(os); 748 return; 749 case DenseArrayBaseAttr::EltType::I16: 750 this->cast<DenseI16ArrayAttr>().printWithoutBraces(os); 751 return; 752 case DenseArrayBaseAttr::EltType::I32: 753 this->cast<DenseI32ArrayAttr>().printWithoutBraces(os); 754 return; 755 case DenseArrayBaseAttr::EltType::I64: 756 this->cast<DenseI64ArrayAttr>().printWithoutBraces(os); 757 return; 758 case DenseArrayBaseAttr::EltType::F32: 759 this->cast<DenseF32ArrayAttr>().printWithoutBraces(os); 760 return; 761 case DenseArrayBaseAttr::EltType::F64: 762 this->cast<DenseF64ArrayAttr>().printWithoutBraces(os); 763 return; 764 } 765 llvm_unreachable("<unknown DenseArrayBaseAttr>"); 766 } 767 768 void DenseArrayBaseAttr::print(raw_ostream &os) const { 769 os << "["; 770 printWithoutBraces(os); 771 os << "]"; 772 } 773 774 template <typename T> 775 void DenseArrayAttr<T>::print(AsmPrinter &printer) const { 776 print(printer.getStream()); 777 } 778 779 template <typename T> 780 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const { 781 ArrayRef<T> values{*this}; 782 llvm::interleaveComma(values, os); 783 } 784 785 /// Specialization for int8_t for forcing printing as number instead of chars. 786 template <> 787 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const { 788 ArrayRef<int8_t> values{*this}; 789 llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); 790 } 791 792 template <typename T> 793 void DenseArrayAttr<T>::print(raw_ostream &os) const { 794 os << "["; 795 printWithoutBraces(os); 796 os << "]"; 797 } 798 799 /// Parse a single element: generic template for int types, specialized for 800 /// floating points below. 801 template <typename T> 802 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { 803 return parser.parseInteger(value); 804 } 805 806 template <> 807 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) { 808 double doubleVal; 809 if (parser.parseFloat(doubleVal)) 810 return failure(); 811 value = doubleVal; 812 return success(); 813 } 814 815 template <> 816 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) { 817 return parser.parseFloat(value); 818 } 819 820 /// Parse a DenseArrayAttr without the braces: `1, 2, 3` 821 template <typename T> 822 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser, 823 Type odsType) { 824 SmallVector<T> data; 825 if (failed(parser.parseCommaSeparatedList([&]() { 826 T value; 827 if (parseDenseArrayAttrElt(parser, value)) 828 return failure(); 829 data.push_back(value); 830 return success(); 831 }))) 832 return {}; 833 return get(parser.getContext(), data); 834 } 835 836 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` 837 template <typename T> 838 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) { 839 if (parser.parseLSquare()) 840 return {}; 841 Attribute result = parseWithoutBraces(parser, odsType); 842 if (parser.parseRSquare()) 843 return {}; 844 return result; 845 } 846 847 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. 848 template <typename T> 849 DenseArrayAttr<T>::operator ArrayRef<T>() const { 850 ArrayRef<char> raw = getImpl()->elements; 851 assert((raw.size() % sizeof(T)) == 0); 852 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), 853 raw.size() / sizeof(T)); 854 } 855 856 namespace { 857 /// Mapping from C++ element type to MLIR DenseArrayAttr internals. 858 template <typename T> 859 struct denseArrayAttrEltTypeBuilder; 860 template <> 861 struct denseArrayAttrEltTypeBuilder<int8_t> { 862 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8; 863 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 864 return VectorType::get(shape, IntegerType::get(context, 8)); 865 } 866 }; 867 template <> 868 struct denseArrayAttrEltTypeBuilder<int16_t> { 869 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16; 870 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 871 return VectorType::get(shape, IntegerType::get(context, 16)); 872 } 873 }; 874 template <> 875 struct denseArrayAttrEltTypeBuilder<int32_t> { 876 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32; 877 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 878 return VectorType::get(shape, IntegerType::get(context, 32)); 879 } 880 }; 881 template <> 882 struct denseArrayAttrEltTypeBuilder<int64_t> { 883 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64; 884 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 885 return VectorType::get(shape, IntegerType::get(context, 64)); 886 } 887 }; 888 template <> 889 struct denseArrayAttrEltTypeBuilder<float> { 890 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; 891 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 892 return VectorType::get(shape, Float32Type::get(context)); 893 } 894 }; 895 template <> 896 struct denseArrayAttrEltTypeBuilder<double> { 897 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64; 898 static ShapedType getShapedType(MLIRContext *context, int64_t shape) { 899 return VectorType::get(shape, Float64Type::get(context)); 900 } 901 }; 902 } // namespace 903 904 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. 905 template <typename T> 906 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context, 907 ArrayRef<T> content) { 908 auto shapedType = 909 denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size()); 910 auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType; 911 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), 912 content.size() * sizeof(T)); 913 return Base::get(context, shapedType, eltType, rawArray) 914 .template cast<DenseArrayAttr<T>>(); 915 } 916 917 template <typename T> 918 bool DenseArrayAttr<T>::classof(Attribute attr) { 919 return attr.isa<DenseArrayBaseAttr>() && 920 attr.cast<DenseArrayBaseAttr>().getElementType() == 921 denseArrayAttrEltTypeBuilder<T>::eltType; 922 } 923 924 namespace mlir { 925 namespace detail { 926 // Explicit instantiation for all the supported DenseArrayAttr. 927 template class DenseArrayAttr<int8_t>; 928 template class DenseArrayAttr<int16_t>; 929 template class DenseArrayAttr<int32_t>; 930 template class DenseArrayAttr<int64_t>; 931 template class DenseArrayAttr<float>; 932 template class DenseArrayAttr<double>; 933 } // namespace detail 934 } // namespace mlir 935 936 //===----------------------------------------------------------------------===// 937 // DenseElementsAttr 938 //===----------------------------------------------------------------------===// 939 940 /// Method for support type inquiry through isa, cast and dyn_cast. 941 bool DenseElementsAttr::classof(Attribute attr) { 942 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 943 } 944 945 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 946 ArrayRef<Attribute> values) { 947 assert(hasSameElementsOrSplat(type, values)); 948 949 // If the element type is not based on int/float/index, assume it is a string 950 // type. 951 auto eltType = type.getElementType(); 952 if (!type.getElementType().isIntOrIndexOrFloat()) { 953 SmallVector<StringRef, 8> stringValues; 954 stringValues.reserve(values.size()); 955 for (Attribute attr : values) { 956 assert(attr.isa<StringAttr>() && 957 "expected string value for non integer/index/float element"); 958 stringValues.push_back(attr.cast<StringAttr>().getValue()); 959 } 960 return get(type, stringValues); 961 } 962 963 // Otherwise, get the raw storage width to use for the allocation. 964 size_t bitWidth = getDenseElementBitWidth(eltType); 965 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 966 967 // Compress the attribute values into a character buffer. 968 SmallVector<char, 8> data( 969 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); 970 APInt intVal; 971 for (unsigned i = 0, e = values.size(); i < e; ++i) { 972 assert(eltType == values[i].getType() && 973 "expected attribute value to have element type"); 974 if (eltType.isa<FloatType>()) 975 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 976 else if (eltType.isa<IntegerType, IndexType>()) 977 intVal = values[i].cast<IntegerAttr>().getValue(); 978 else 979 llvm_unreachable("unexpected element type"); 980 981 assert(intVal.getBitWidth() == bitWidth && 982 "expected value to have same bitwidth as element type"); 983 writeBits(data.data(), i * storageBitWidth, intVal); 984 } 985 986 // Handle the special encoding of splat of bool. 987 if (values.size() == 1 && values[0].getType().isInteger(1)) 988 data[0] = data[0] ? -1 : 0; 989 990 return DenseIntOrFPElementsAttr::getRaw(type, data); 991 } 992 993 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 994 ArrayRef<bool> values) { 995 assert(hasSameElementsOrSplat(type, values)); 996 assert(type.getElementType().isInteger(1)); 997 998 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 999 1000 if (!values.empty()) { 1001 bool isSplat = true; 1002 bool firstValue = values[0]; 1003 for (int i = 0, e = values.size(); i != e; ++i) { 1004 isSplat &= values[i] == firstValue; 1005 setBit(buff.data(), i, values[i]); 1006 } 1007 1008 // Splat of bool is encoded as a byte with all-ones in it. 1009 if (isSplat) { 1010 buff.resize(1); 1011 buff[0] = values[0] ? -1 : 0; 1012 } 1013 } 1014 1015 return DenseIntOrFPElementsAttr::getRaw(type, buff); 1016 } 1017 1018 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1019 ArrayRef<StringRef> values) { 1020 assert(!type.getElementType().isIntOrFloat()); 1021 return DenseStringElementsAttr::get(type, values); 1022 } 1023 1024 /// Constructs a dense integer elements attribute from an array of APInt 1025 /// values. Each APInt value is expected to have the same bitwidth as the 1026 /// element type of 'type'. 1027 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1028 ArrayRef<APInt> values) { 1029 assert(type.getElementType().isIntOrIndex()); 1030 assert(hasSameElementsOrSplat(type, values)); 1031 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1032 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1033 } 1034 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1035 ArrayRef<std::complex<APInt>> values) { 1036 ComplexType complex = type.getElementType().cast<ComplexType>(); 1037 assert(complex.getElementType().isa<IntegerType>()); 1038 assert(hasSameElementsOrSplat(type, values)); 1039 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1040 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 1041 values.size() * 2); 1042 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); 1043 } 1044 1045 // Constructs a dense float elements attribute from an array of APFloat 1046 // values. Each APFloat value is expected to have the same bitwidth as the 1047 // element type of 'type'. 1048 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1049 ArrayRef<APFloat> values) { 1050 assert(type.getElementType().isa<FloatType>()); 1051 assert(hasSameElementsOrSplat(type, values)); 1052 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1053 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1054 } 1055 DenseElementsAttr 1056 DenseElementsAttr::get(ShapedType type, 1057 ArrayRef<std::complex<APFloat>> values) { 1058 ComplexType complex = type.getElementType().cast<ComplexType>(); 1059 assert(complex.getElementType().isa<FloatType>()); 1060 assert(hasSameElementsOrSplat(type, values)); 1061 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 1062 values.size() * 2); 1063 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1064 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); 1065 } 1066 1067 /// Construct a dense elements attribute from a raw buffer representing the 1068 /// data for this attribute. Users should generally not use this methods as 1069 /// the expected buffer format may not be a form the user expects. 1070 DenseElementsAttr 1071 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { 1072 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); 1073 } 1074 1075 /// Returns true if the given buffer is a valid raw buffer for the given type. 1076 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 1077 ArrayRef<char> rawBuffer, 1078 bool &detectedSplat) { 1079 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 1080 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 1081 int64_t numElements = type.getNumElements(); 1082 1083 // The initializer is always a splat if the result type has a single element. 1084 detectedSplat = numElements == 1; 1085 1086 // Storage width of 1 is special as it is packed by the bit. 1087 if (storageWidth == 1) { 1088 // Check for a splat, or a buffer equal to the number of elements which 1089 // consists of either all 0's or all 1's. 1090 if (rawBuffer.size() == 1) { 1091 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 1092 if (rawByte == 0 || rawByte == 0xff) { 1093 detectedSplat = true; 1094 return true; 1095 } 1096 } 1097 1098 // This is a valid non-splat buffer if it has the right size. 1099 return rawBufferWidth == llvm::alignTo<8>(numElements); 1100 } 1101 1102 // All other types are 8-bit aligned, so we can just check the buffer width 1103 // to know if only a single initializer element was passed in. 1104 if (rawBufferWidth == storageWidth) { 1105 detectedSplat = true; 1106 return true; 1107 } 1108 1109 // The raw buffer is valid if it has the right size. 1110 return rawBufferWidth == storageWidth * numElements; 1111 } 1112 1113 /// Check the information for a C++ data type, check if this type is valid for 1114 /// the current attribute. This method is used to verify specific type 1115 /// invariants that the templatized 'getValues' method cannot. 1116 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 1117 bool isSigned) { 1118 // Make sure that the data element size is the same as the type element width. 1119 if (getDenseElementBitWidth(type) != 1120 static_cast<size_t>(dataEltSize * CHAR_BIT)) 1121 return false; 1122 1123 // Check that the element type is either float or integer or index. 1124 if (!isInt) 1125 return type.isa<FloatType>(); 1126 if (type.isIndex()) 1127 return true; 1128 1129 auto intType = type.dyn_cast<IntegerType>(); 1130 if (!intType) 1131 return false; 1132 1133 // Make sure signedness semantics is consistent. 1134 if (intType.isSignless()) 1135 return true; 1136 return intType.isSigned() ? isSigned : !isSigned; 1137 } 1138 1139 /// Defaults down the subclass implementation. 1140 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 1141 ArrayRef<char> data, 1142 int64_t dataEltSize, 1143 bool isInt, bool isSigned) { 1144 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 1145 isSigned); 1146 } 1147 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 1148 ArrayRef<char> data, 1149 int64_t dataEltSize, 1150 bool isInt, 1151 bool isSigned) { 1152 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 1153 isInt, isSigned); 1154 } 1155 1156 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 1157 bool isSigned) const { 1158 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 1159 } 1160 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 1161 bool isSigned) const { 1162 return ::isValidIntOrFloat( 1163 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 1164 isInt, isSigned); 1165 } 1166 1167 /// Returns true if this attribute corresponds to a splat, i.e. if all element 1168 /// values are the same. 1169 bool DenseElementsAttr::isSplat() const { 1170 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 1171 } 1172 1173 /// Return if the given complex type has an integer element type. 1174 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { 1175 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 1176 } 1177 1178 auto DenseElementsAttr::getComplexIntValues() const 1179 -> iterator_range_impl<ComplexIntElementIterator> { 1180 assert(isComplexOfIntType(getElementType()) && 1181 "expected complex integral type"); 1182 return {getType(), ComplexIntElementIterator(*this, 0), 1183 ComplexIntElementIterator(*this, getNumElements())}; 1184 } 1185 auto DenseElementsAttr::complex_value_begin() const 1186 -> ComplexIntElementIterator { 1187 assert(isComplexOfIntType(getElementType()) && 1188 "expected complex integral type"); 1189 return ComplexIntElementIterator(*this, 0); 1190 } 1191 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 1192 assert(isComplexOfIntType(getElementType()) && 1193 "expected complex integral type"); 1194 return ComplexIntElementIterator(*this, getNumElements()); 1195 } 1196 1197 /// Return the held element values as a range of APFloat. The element type of 1198 /// this attribute must be of float type. 1199 auto DenseElementsAttr::getFloatValues() const 1200 -> iterator_range_impl<FloatElementIterator> { 1201 auto elementType = getElementType().cast<FloatType>(); 1202 const auto &elementSemantics = elementType.getFloatSemantics(); 1203 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 1204 FloatElementIterator(elementSemantics, raw_int_end())}; 1205 } 1206 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1207 auto elementType = getElementType().cast<FloatType>(); 1208 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 1209 } 1210 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1211 auto elementType = getElementType().cast<FloatType>(); 1212 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 1213 } 1214 1215 auto DenseElementsAttr::getComplexFloatValues() const 1216 -> iterator_range_impl<ComplexFloatElementIterator> { 1217 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1218 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1219 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1220 return {getType(), 1221 {semantics, {*this, 0}}, 1222 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1223 } 1224 auto DenseElementsAttr::complex_float_value_begin() const 1225 -> ComplexFloatElementIterator { 1226 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1227 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1228 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 1229 } 1230 auto DenseElementsAttr::complex_float_value_end() const 1231 -> ComplexFloatElementIterator { 1232 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1233 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1234 return {eltTy.cast<FloatType>().getFloatSemantics(), 1235 {*this, static_cast<size_t>(getNumElements())}}; 1236 } 1237 1238 /// Return the raw storage data held by this attribute. 1239 ArrayRef<char> DenseElementsAttr::getRawData() const { 1240 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1241 } 1242 1243 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1244 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1245 } 1246 1247 /// Return a new DenseElementsAttr that has the same data as the current 1248 /// attribute, but has been reshaped to 'newType'. The new type must have the 1249 /// same total number of elements as well as element type. 1250 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1251 ShapedType curType = getType(); 1252 if (curType == newType) 1253 return *this; 1254 1255 assert(newType.getElementType() == curType.getElementType() && 1256 "expected the same element type"); 1257 assert(newType.getNumElements() == curType.getNumElements() && 1258 "expected the same number of elements"); 1259 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1260 } 1261 1262 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { 1263 assert(isSplat() && "expected a splat type"); 1264 1265 ShapedType curType = getType(); 1266 if (curType == newType) 1267 return *this; 1268 1269 assert(newType.getElementType() == curType.getElementType() && 1270 "expected the same element type"); 1271 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1272 } 1273 1274 /// Return a new DenseElementsAttr that has the same data as the current 1275 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1276 /// type must have the same shape and element types of the same bitwidth as the 1277 /// current type. 1278 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1279 ShapedType curType = getType(); 1280 Type curElType = curType.getElementType(); 1281 if (curElType == newElType) 1282 return *this; 1283 1284 assert(getDenseElementBitWidth(newElType) == 1285 getDenseElementBitWidth(curElType) && 1286 "expected element types with the same bitwidth"); 1287 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1288 getRawData()); 1289 } 1290 1291 DenseElementsAttr 1292 DenseElementsAttr::mapValues(Type newElementType, 1293 function_ref<APInt(const APInt &)> mapping) const { 1294 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1295 } 1296 1297 DenseElementsAttr DenseElementsAttr::mapValues( 1298 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1299 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1300 } 1301 1302 ShapedType DenseElementsAttr::getType() const { 1303 return Attribute::getType().cast<ShapedType>(); 1304 } 1305 1306 Type DenseElementsAttr::getElementType() const { 1307 return getType().getElementType(); 1308 } 1309 1310 int64_t DenseElementsAttr::getNumElements() const { 1311 return getType().getNumElements(); 1312 } 1313 1314 //===----------------------------------------------------------------------===// 1315 // DenseIntOrFPElementsAttr 1316 //===----------------------------------------------------------------------===// 1317 1318 /// Utility method to write a range of APInt values to a buffer. 1319 template <typename APRangeT> 1320 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1321 APRangeT &&values) { 1322 size_t numValues = llvm::size(values); 1323 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); 1324 size_t offset = 0; 1325 for (auto it = values.begin(), e = values.end(); it != e; 1326 ++it, offset += storageWidth) { 1327 assert((*it).getBitWidth() <= storageWidth); 1328 writeBits(data.data(), offset, *it); 1329 } 1330 1331 // Handle the special encoding of splat of a boolean. 1332 if (numValues == 1 && (*values.begin()).getBitWidth() == 1) 1333 data[0] = data[0] ? -1 : 0; 1334 } 1335 1336 /// Constructs a dense elements attribute from an array of raw APFloat values. 1337 /// Each APFloat value is expected to have the same bitwidth as the element 1338 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1339 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1340 size_t storageWidth, 1341 ArrayRef<APFloat> values) { 1342 std::vector<char> data; 1343 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1344 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1345 return DenseIntOrFPElementsAttr::getRaw(type, data); 1346 } 1347 1348 /// Constructs a dense elements attribute from an array of raw APInt values. 1349 /// Each APInt value is expected to have the same bitwidth as the element type 1350 /// of 'type'. 1351 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1352 size_t storageWidth, 1353 ArrayRef<APInt> values) { 1354 std::vector<char> data; 1355 writeAPIntsToBuffer(storageWidth, data, values); 1356 return DenseIntOrFPElementsAttr::getRaw(type, data); 1357 } 1358 1359 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1360 ArrayRef<char> data) { 1361 assert((type.isa<RankedTensorType, VectorType>()) && 1362 "type must be ranked tensor or vector"); 1363 assert(type.hasStaticShape() && "type must have static shape"); 1364 bool isSplat = false; 1365 bool isValid = isValidRawBuffer(type, data, isSplat); 1366 assert(isValid); 1367 (void)isValid; 1368 return Base::get(type.getContext(), type, data, isSplat); 1369 } 1370 1371 /// Overload of the raw 'get' method that asserts that the given type is of 1372 /// complex type. This method is used to verify type invariants that the 1373 /// templatized 'get' method cannot. 1374 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1375 ArrayRef<char> data, 1376 int64_t dataEltSize, 1377 bool isInt, 1378 bool isSigned) { 1379 assert(::isValidIntOrFloat( 1380 type.getElementType().cast<ComplexType>().getElementType(), 1381 dataEltSize / 2, isInt, isSigned)); 1382 1383 int64_t numElements = data.size() / dataEltSize; 1384 (void)numElements; 1385 assert(numElements == 1 || numElements == type.getNumElements()); 1386 return getRaw(type, data); 1387 } 1388 1389 /// Overload of the 'getRaw' method that asserts that the given type is of 1390 /// integer type. This method is used to verify type invariants that the 1391 /// templatized 'get' method cannot. 1392 DenseElementsAttr 1393 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1394 int64_t dataEltSize, bool isInt, 1395 bool isSigned) { 1396 assert( 1397 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1398 1399 int64_t numElements = data.size() / dataEltSize; 1400 assert(numElements == 1 || numElements == type.getNumElements()); 1401 (void)numElements; 1402 return getRaw(type, data); 1403 } 1404 1405 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1406 const char *inRawData, char *outRawData, size_t elementBitWidth, 1407 size_t numElements) { 1408 using llvm::support::ulittle16_t; 1409 using llvm::support::ulittle32_t; 1410 using llvm::support::ulittle64_t; 1411 1412 assert(llvm::support::endian::system_endianness() == // NOLINT 1413 llvm::support::endianness::big); // NOLINT 1414 // NOLINT to avoid warning message about replacing by static_assert() 1415 1416 // Following std::copy_n always converts endianness on BE machine. 1417 switch (elementBitWidth) { 1418 case 16: { 1419 const ulittle16_t *inRawDataPos = 1420 reinterpret_cast<const ulittle16_t *>(inRawData); 1421 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1422 std::copy_n(inRawDataPos, numElements, outDataPos); 1423 break; 1424 } 1425 case 32: { 1426 const ulittle32_t *inRawDataPos = 1427 reinterpret_cast<const ulittle32_t *>(inRawData); 1428 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1429 std::copy_n(inRawDataPos, numElements, outDataPos); 1430 break; 1431 } 1432 case 64: { 1433 const ulittle64_t *inRawDataPos = 1434 reinterpret_cast<const ulittle64_t *>(inRawData); 1435 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1436 std::copy_n(inRawDataPos, numElements, outDataPos); 1437 break; 1438 } 1439 default: { 1440 size_t nBytes = elementBitWidth / CHAR_BIT; 1441 for (size_t i = 0; i < nBytes; i++) 1442 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1443 break; 1444 } 1445 } 1446 } 1447 1448 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1449 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1450 ShapedType type) { 1451 size_t numElements = type.getNumElements(); 1452 Type elementType = type.getElementType(); 1453 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1454 elementType = complexTy.getElementType(); 1455 numElements = numElements * 2; 1456 } 1457 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1458 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1459 inRawData.size() <= outRawData.size()); 1460 if (elementBitWidth <= CHAR_BIT) 1461 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); 1462 else 1463 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1464 elementBitWidth, numElements); 1465 } 1466 1467 //===----------------------------------------------------------------------===// 1468 // DenseFPElementsAttr 1469 //===----------------------------------------------------------------------===// 1470 1471 template <typename Fn, typename Attr> 1472 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1473 Type newElementType, 1474 llvm::SmallVectorImpl<char> &data) { 1475 size_t bitWidth = getDenseElementBitWidth(newElementType); 1476 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1477 1478 ShapedType newArrayType; 1479 if (inType.isa<RankedTensorType>()) 1480 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1481 else if (inType.isa<UnrankedTensorType>()) 1482 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1483 else if (auto vType = inType.dyn_cast<VectorType>()) 1484 newArrayType = VectorType::get(vType.getShape(), newElementType, 1485 vType.getNumScalableDims()); 1486 else 1487 assert(newArrayType && "Unhandled tensor type"); 1488 1489 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1490 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); 1491 1492 // Functor used to process a single element value of the attribute. 1493 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1494 auto newInt = mapping(value); 1495 assert(newInt.getBitWidth() == bitWidth); 1496 writeBits(data.data(), index * storageBitWidth, newInt); 1497 }; 1498 1499 // Check for the splat case. 1500 if (attr.isSplat()) { 1501 processElt(*attr.begin(), /*index=*/0); 1502 return newArrayType; 1503 } 1504 1505 // Otherwise, process all of the element values. 1506 uint64_t elementIdx = 0; 1507 for (auto value : attr) 1508 processElt(value, elementIdx++); 1509 return newArrayType; 1510 } 1511 1512 DenseElementsAttr DenseFPElementsAttr::mapValues( 1513 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1514 llvm::SmallVector<char, 8> elementData; 1515 auto newArrayType = 1516 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1517 1518 return getRaw(newArrayType, elementData); 1519 } 1520 1521 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1522 bool DenseFPElementsAttr::classof(Attribute attr) { 1523 return attr.isa<DenseElementsAttr>() && 1524 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1525 } 1526 1527 //===----------------------------------------------------------------------===// 1528 // DenseIntElementsAttr 1529 //===----------------------------------------------------------------------===// 1530 1531 DenseElementsAttr DenseIntElementsAttr::mapValues( 1532 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1533 llvm::SmallVector<char, 8> elementData; 1534 auto newArrayType = 1535 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1536 return getRaw(newArrayType, elementData); 1537 } 1538 1539 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1540 bool DenseIntElementsAttr::classof(Attribute attr) { 1541 return attr.isa<DenseElementsAttr>() && 1542 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1543 } 1544 1545 //===----------------------------------------------------------------------===// 1546 // OpaqueElementsAttr 1547 //===----------------------------------------------------------------------===// 1548 1549 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1550 Dialect *dialect = getContext()->getLoadedDialect(getDialect()); 1551 if (!dialect) 1552 return true; 1553 auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect); 1554 if (!interface) 1555 return true; 1556 return failed(interface->decode(*this, result)); 1557 } 1558 1559 LogicalResult 1560 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1561 StringAttr dialect, StringRef value, 1562 ShapedType type) { 1563 if (!Dialect::isValidNamespace(dialect.strref())) 1564 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1565 return success(); 1566 } 1567 1568 //===----------------------------------------------------------------------===// 1569 // SparseElementsAttr 1570 //===----------------------------------------------------------------------===// 1571 1572 /// Get a zero APFloat for the given sparse attribute. 1573 APFloat SparseElementsAttr::getZeroAPFloat() const { 1574 auto eltType = getElementType().cast<FloatType>(); 1575 return APFloat(eltType.getFloatSemantics()); 1576 } 1577 1578 /// Get a zero APInt for the given sparse attribute. 1579 APInt SparseElementsAttr::getZeroAPInt() const { 1580 auto eltType = getElementType().cast<IntegerType>(); 1581 return APInt::getZero(eltType.getWidth()); 1582 } 1583 1584 /// Get a zero attribute for the given attribute type. 1585 Attribute SparseElementsAttr::getZeroAttr() const { 1586 auto eltType = getElementType(); 1587 1588 // Handle floating point elements. 1589 if (eltType.isa<FloatType>()) 1590 return FloatAttr::get(eltType, 0); 1591 1592 // Handle string type. 1593 if (getValues().isa<DenseStringElementsAttr>()) 1594 return StringAttr::get("", eltType); 1595 1596 // Otherwise, this is an integer. 1597 return IntegerAttr::get(eltType, 0); 1598 } 1599 1600 /// Flatten, and return, all of the sparse indices in this attribute in 1601 /// row-major order. 1602 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1603 std::vector<ptrdiff_t> flatSparseIndices; 1604 1605 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1606 // as a 1-D index array. 1607 auto sparseIndices = getIndices(); 1608 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1609 if (sparseIndices.isSplat()) { 1610 SmallVector<uint64_t, 8> indices(getType().getRank(), 1611 *sparseIndexValues.begin()); 1612 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1613 return flatSparseIndices; 1614 } 1615 1616 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1617 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1618 size_t rank = getType().getRank(); 1619 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1620 flatSparseIndices.push_back(getFlattenedIndex( 1621 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1622 return flatSparseIndices; 1623 } 1624 1625 LogicalResult 1626 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1627 ShapedType type, DenseIntElementsAttr sparseIndices, 1628 DenseElementsAttr values) { 1629 ShapedType valuesType = values.getType(); 1630 if (valuesType.getRank() != 1) 1631 return emitError() << "expected 1-d tensor for sparse element values"; 1632 1633 // Verify the indices and values shape. 1634 ShapedType indicesType = sparseIndices.getType(); 1635 auto emitShapeError = [&]() { 1636 return emitError() << "expected shape ([" << type.getShape() 1637 << "]); inferred shape of indices literal ([" 1638 << indicesType.getShape() 1639 << "]); inferred shape of values literal ([" 1640 << valuesType.getShape() << "])"; 1641 }; 1642 // Verify indices shape. 1643 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1644 if (indicesRank == 2) { 1645 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1646 return emitShapeError(); 1647 } else if (indicesRank != 1 || rank != 1) { 1648 return emitShapeError(); 1649 } 1650 // Verify the values shape. 1651 int64_t numSparseIndices = indicesType.getDimSize(0); 1652 if (numSparseIndices != valuesType.getDimSize(0)) 1653 return emitShapeError(); 1654 1655 // Verify that the sparse indices are within the value shape. 1656 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1657 return emitError() 1658 << "sparse index #" << indexNum 1659 << " is not contained within the value shape, with index=[" << index 1660 << "], and type=" << type; 1661 }; 1662 1663 // Handle the case where the index values are a splat. 1664 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1665 if (sparseIndices.isSplat()) { 1666 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1667 if (!ElementsAttr::isValidIndex(type, indices)) 1668 return emitIndexError(0, indices); 1669 return success(); 1670 } 1671 1672 // Otherwise, reinterpret each index as an ArrayRef. 1673 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1674 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1675 rank); 1676 if (!ElementsAttr::isValidIndex(type, index)) 1677 return emitIndexError(i, index); 1678 } 1679 1680 return success(); 1681 } 1682 1683 //===----------------------------------------------------------------------===// 1684 // TypeAttr 1685 //===----------------------------------------------------------------------===// 1686 1687 void TypeAttr::walkImmediateSubElements( 1688 function_ref<void(Attribute)> walkAttrsFn, 1689 function_ref<void(Type)> walkTypesFn) const { 1690 walkTypesFn(getValue()); 1691 } 1692