1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "mlir/IR/Types.h" 19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 20 #include "llvm/ADT/APSInt.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/Support/Endian.h" 23 24 using namespace mlir; 25 using namespace mlir::detail; 26 27 //===----------------------------------------------------------------------===// 28 /// Tablegen Attribute Definitions 29 //===----------------------------------------------------------------------===// 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/IR/BuiltinAttributes.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 // BuiltinDialect 36 //===----------------------------------------------------------------------===// 37 38 void BuiltinDialect::registerAttributes() { 39 addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr, 40 DenseIntOrFPElementsAttr, DenseStringElementsAttr, 41 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr, 42 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr, 43 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // ArrayAttr 48 //===----------------------------------------------------------------------===// 49 50 void ArrayAttr::walkImmediateSubElements( 51 function_ref<void(Attribute)> walkAttrsFn, 52 function_ref<void(Type)> walkTypesFn) const { 53 for (Attribute attr : getValue()) 54 walkAttrsFn(attr); 55 } 56 57 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute( 58 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 59 std::vector<Attribute> vector = getValue().vec(); 60 for (auto &it : replacements) { 61 vector[it.first] = it.second; 62 } 63 return get(getContext(), vector); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // DictionaryAttr 68 //===----------------------------------------------------------------------===// 69 70 /// Helper function that does either an in place sort or sorts from source array 71 /// into destination. If inPlace then storage is both the source and the 72 /// destination, else value is the source and storage destination. Returns 73 /// whether source was sorted. 74 template <bool inPlace> 75 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 76 SmallVectorImpl<NamedAttribute> &storage) { 77 // Specialize for the common case. 78 switch (value.size()) { 79 case 0: 80 // Zero already sorted. 81 if (!inPlace) 82 storage.clear(); 83 break; 84 case 1: 85 // One already sorted but may need to be copied. 86 if (!inPlace) 87 storage.assign({value[0]}); 88 break; 89 case 2: { 90 bool isSorted = value[0] < value[1]; 91 if (inPlace) { 92 if (!isSorted) 93 std::swap(storage[0], storage[1]); 94 } else if (isSorted) { 95 storage.assign({value[0], value[1]}); 96 } else { 97 storage.assign({value[1], value[0]}); 98 } 99 return !isSorted; 100 } 101 default: 102 if (!inPlace) 103 storage.assign(value.begin(), value.end()); 104 // Check to see they are sorted already. 105 bool isSorted = llvm::is_sorted(value); 106 // If not, do a general sort. 107 if (!isSorted) 108 llvm::array_pod_sort(storage.begin(), storage.end()); 109 return !isSorted; 110 } 111 return false; 112 } 113 114 /// Returns an entry with a duplicate name from the given sorted array of named 115 /// attributes. Returns llvm::None if all elements have unique names. 116 static Optional<NamedAttribute> 117 findDuplicateElement(ArrayRef<NamedAttribute> value) { 118 const Optional<NamedAttribute> none{llvm::None}; 119 if (value.size() < 2) 120 return none; 121 122 if (value.size() == 2) 123 return value[0].getName() == value[1].getName() ? value[0] : none; 124 125 const auto *it = std::adjacent_find(value.begin(), value.end(), 126 [](NamedAttribute l, NamedAttribute r) { 127 return l.getName() == r.getName(); 128 }); 129 return it != value.end() ? *it : none; 130 } 131 132 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 133 SmallVectorImpl<NamedAttribute> &storage) { 134 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 135 assert(!findDuplicateElement(storage) && 136 "DictionaryAttr element names must be unique"); 137 return isSorted; 138 } 139 140 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 141 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 142 assert(!findDuplicateElement(array) && 143 "DictionaryAttr element names must be unique"); 144 return isSorted; 145 } 146 147 Optional<NamedAttribute> 148 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 149 bool isSorted) { 150 if (!isSorted) 151 dictionaryAttrSort</*inPlace=*/true>(array, array); 152 return findDuplicateElement(array); 153 } 154 155 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 156 ArrayRef<NamedAttribute> value) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 160 // We need to sort the element list to canonicalize it. 161 SmallVector<NamedAttribute, 8> storage; 162 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 163 value = storage; 164 assert(!findDuplicateElement(value) && 165 "DictionaryAttr element names must be unique"); 166 return Base::get(context, value); 167 } 168 /// Construct a dictionary with an array of values that is known to already be 169 /// sorted by name and uniqued. 170 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 171 ArrayRef<NamedAttribute> value) { 172 if (value.empty()) 173 return DictionaryAttr::getEmpty(context); 174 // Ensure that the attribute elements are unique and sorted. 175 assert(llvm::is_sorted( 176 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && 177 "expected attribute values to be sorted"); 178 assert(!findDuplicateElement(value) && 179 "DictionaryAttr element names must be unique"); 180 return Base::get(context, value); 181 } 182 183 /// Return the specified attribute if present, null otherwise. 184 Attribute DictionaryAttr::get(StringRef name) const { 185 auto it = impl::findAttrSorted(begin(), end(), name); 186 return it.second ? it.first->getValue() : Attribute(); 187 } 188 Attribute DictionaryAttr::get(StringAttr name) const { 189 auto it = impl::findAttrSorted(begin(), end(), name); 190 return it.second ? it.first->getValue() : Attribute(); 191 } 192 193 /// Return the specified named attribute if present, None otherwise. 194 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 195 auto it = impl::findAttrSorted(begin(), end(), name); 196 return it.second ? *it.first : Optional<NamedAttribute>(); 197 } 198 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { 199 auto it = impl::findAttrSorted(begin(), end(), name); 200 return it.second ? *it.first : Optional<NamedAttribute>(); 201 } 202 203 /// Return whether the specified attribute is present. 204 bool DictionaryAttr::contains(StringRef name) const { 205 return impl::findAttrSorted(begin(), end(), name).second; 206 } 207 bool DictionaryAttr::contains(StringAttr name) const { 208 return impl::findAttrSorted(begin(), end(), name).second; 209 } 210 211 DictionaryAttr::iterator DictionaryAttr::begin() const { 212 return getValue().begin(); 213 } 214 DictionaryAttr::iterator DictionaryAttr::end() const { 215 return getValue().end(); 216 } 217 size_t DictionaryAttr::size() const { return getValue().size(); } 218 219 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 220 return Base::get(context, ArrayRef<NamedAttribute>()); 221 } 222 223 void DictionaryAttr::walkImmediateSubElements( 224 function_ref<void(Attribute)> walkAttrsFn, 225 function_ref<void(Type)> walkTypesFn) const { 226 for (const NamedAttribute &attr : getValue()) 227 walkAttrsFn(attr.getValue()); 228 } 229 230 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( 231 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 232 std::vector<NamedAttribute> vec = getValue().vec(); 233 for (auto &it : replacements) 234 vec[it.first].setValue(it.second); 235 236 // The above only modifies the mapped value, but not the key, and therefore 237 // not the order of the elements. It remains sorted 238 return getWithSorted(getContext(), vec); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // StringAttr 243 //===----------------------------------------------------------------------===// 244 245 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 246 return Base::get(context, "", NoneType::get(context)); 247 } 248 249 /// Twine support for StringAttr. 250 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 251 // Fast-path empty twine. 252 if (twine.isTriviallyEmpty()) 253 return get(context); 254 SmallVector<char, 32> tempStr; 255 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 256 } 257 258 /// Twine support for StringAttr. 259 StringAttr StringAttr::get(const Twine &twine, Type type) { 260 SmallVector<char, 32> tempStr; 261 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 262 } 263 264 StringRef StringAttr::getValue() const { return getImpl()->value; } 265 266 Dialect *StringAttr::getReferencedDialect() const { 267 return getImpl()->referencedDialect; 268 } 269 270 //===----------------------------------------------------------------------===// 271 // FloatAttr 272 //===----------------------------------------------------------------------===// 273 274 double FloatAttr::getValueAsDouble() const { 275 return getValueAsDouble(getValue()); 276 } 277 double FloatAttr::getValueAsDouble(APFloat value) { 278 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 279 bool losesInfo = false; 280 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 281 &losesInfo); 282 } 283 return value.convertToDouble(); 284 } 285 286 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 287 Type type, APFloat value) { 288 // Verify that the type is correct. 289 if (!type.isa<FloatType>()) 290 return emitError() << "expected floating point type"; 291 292 // Verify that the type semantics match that of the value. 293 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 294 return emitError() 295 << "FloatAttr type doesn't match the type implied by its value"; 296 } 297 return success(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // SymbolRefAttr 302 //===----------------------------------------------------------------------===// 303 304 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 305 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 306 return get(StringAttr::get(ctx, value), nestedRefs); 307 } 308 309 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 310 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 311 } 312 313 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 314 return get(value, {}).cast<FlatSymbolRefAttr>(); 315 } 316 317 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 318 auto symName = 319 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 320 assert(symName && "value does not have a valid symbol name"); 321 return SymbolRefAttr::get(symName); 322 } 323 324 StringAttr SymbolRefAttr::getLeafReference() const { 325 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 326 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // IntegerAttr 331 //===----------------------------------------------------------------------===// 332 333 int64_t IntegerAttr::getInt() const { 334 assert((getType().isIndex() || getType().isSignlessInteger()) && 335 "must be signless integer"); 336 return getValue().getSExtValue(); 337 } 338 339 int64_t IntegerAttr::getSInt() const { 340 assert(getType().isSignedInteger() && "must be signed integer"); 341 return getValue().getSExtValue(); 342 } 343 344 uint64_t IntegerAttr::getUInt() const { 345 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 346 return getValue().getZExtValue(); 347 } 348 349 /// Return the value as an APSInt which carries the signed from the type of 350 /// the attribute. This traps on signless integers types! 351 APSInt IntegerAttr::getAPSInt() const { 352 assert(!getType().isSignlessInteger() && 353 "Signless integers don't carry a sign for APSInt"); 354 return APSInt(getValue(), getType().isUnsignedInteger()); 355 } 356 357 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 358 Type type, APInt value) { 359 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 360 if (integerType.getWidth() != value.getBitWidth()) 361 return emitError() << "integer type bit width (" << integerType.getWidth() 362 << ") doesn't match value bit width (" 363 << value.getBitWidth() << ")"; 364 return success(); 365 } 366 if (type.isa<IndexType>()) 367 return success(); 368 return emitError() << "expected integer or index type"; 369 } 370 371 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 372 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 373 return attr.cast<BoolAttr>(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // BoolAttr 378 //===----------------------------------------------------------------------===// 379 380 bool BoolAttr::getValue() const { 381 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 382 return storage->value.getBoolValue(); 383 } 384 385 bool BoolAttr::classof(Attribute attr) { 386 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 387 return intAttr && intAttr.getType().isSignlessInteger(1); 388 } 389 390 //===----------------------------------------------------------------------===// 391 // OpaqueAttr 392 //===----------------------------------------------------------------------===// 393 394 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 395 StringAttr dialect, StringRef attrData, 396 Type type) { 397 if (!Dialect::isValidNamespace(dialect.strref())) 398 return emitError() << "invalid dialect namespace '" << dialect << "'"; 399 400 // Check that the dialect is actually registered. 401 MLIRContext *context = dialect.getContext(); 402 if (!context->allowsUnregisteredDialects() && 403 !context->getLoadedDialect(dialect.strref())) { 404 return emitError() 405 << "#" << dialect << "<\"" << attrData << "\"> : " << type 406 << " attribute created with unregistered dialect. If this is " 407 "intended, please call allowUnregisteredDialects() on the " 408 "MLIRContext, or use -allow-unregistered-dialect with " 409 "the MLIR opt tool used"; 410 } 411 412 return success(); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // DenseElementsAttr Utilities 417 //===----------------------------------------------------------------------===// 418 419 /// Get the bitwidth of a dense element type within the buffer. 420 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 421 static size_t getDenseElementStorageWidth(size_t origWidth) { 422 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 423 } 424 static size_t getDenseElementStorageWidth(Type elementType) { 425 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 426 } 427 428 /// Set a bit to a specific value. 429 static void setBit(char *rawData, size_t bitPos, bool value) { 430 if (value) 431 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 432 else 433 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 434 } 435 436 /// Return the value of the specified bit. 437 static bool getBit(const char *rawData, size_t bitPos) { 438 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 439 } 440 441 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 442 /// BE format. 443 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 444 char *result) { 445 assert(llvm::support::endian::system_endianness() == // NOLINT 446 llvm::support::endianness::big); // NOLINT 447 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 448 449 // Copy the words filled with data. 450 // For example, when `value` has 2 words, the first word is filled with data. 451 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 452 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 453 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 454 numFilledWords, result); 455 // Convert last word of APInt to LE format and store it in char 456 // array(`valueLE`). 457 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 458 size_t lastWordPos = numFilledWords; 459 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 460 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 461 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 462 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 463 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 464 // and store it in `result`. 465 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 466 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 467 valueLE.begin(), result + lastWordPos, 468 (numBytes - lastWordPos) * CHAR_BIT, 1); 469 } 470 471 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 472 /// format. 473 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 474 APInt &result) { 475 assert(llvm::support::endian::system_endianness() == // NOLINT 476 llvm::support::endianness::big); // NOLINT 477 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 478 479 // Copy the data that fills the word of `result` from `inArray`. 480 // For example, when `result` has 2 words, the first word will be filled with 481 // data. So, the first 8 bytes are copied from `inArray` here. 482 // `inArray` (10 bytes, BE): |abcdefgh|ij| 483 // ==> `result` (2 words, BE): |abcdefgh|--------| 484 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 485 std::copy_n( 486 inArray, numFilledWords, 487 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 488 489 // Convert array data which will be last word of `result` to LE format, and 490 // store it in char array(`inArrayLE`). 491 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 492 size_t lastWordPos = numFilledWords; 493 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 494 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 495 inArray + lastWordPos, inArrayLE.begin(), 496 (numBytes - lastWordPos) * CHAR_BIT, 1); 497 498 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 499 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 500 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 501 inArrayLE.begin(), 502 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 503 lastWordPos, 504 APInt::APINT_BITS_PER_WORD, 1); 505 } 506 507 /// Writes value to the bit position `bitPos` in array `rawData`. 508 static void writeBits(char *rawData, size_t bitPos, APInt value) { 509 size_t bitWidth = value.getBitWidth(); 510 511 // If the bitwidth is 1 we just toggle the specific bit. 512 if (bitWidth == 1) 513 return setBit(rawData, bitPos, value.isOneValue()); 514 515 // Otherwise, the bit position is guaranteed to be byte aligned. 516 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 517 if (llvm::support::endian::system_endianness() == 518 llvm::support::endianness::big) { 519 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 520 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 521 // work correctly in BE format. 522 // ex. `value` (2 words including 10 bytes) 523 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 524 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 525 rawData + (bitPos / CHAR_BIT)); 526 } else { 527 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 528 llvm::divideCeil(bitWidth, CHAR_BIT), 529 rawData + (bitPos / CHAR_BIT)); 530 } 531 } 532 533 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 534 /// `rawData`. 535 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 536 // Handle a boolean bit position. 537 if (bitWidth == 1) 538 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 539 540 // Otherwise, the bit position must be 8-bit aligned. 541 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 542 APInt result(bitWidth, 0); 543 if (llvm::support::endian::system_endianness() == 544 llvm::support::endianness::big) { 545 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 546 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 547 // work correctly in BE format. 548 // ex. `result` (2 words including 10 bytes) 549 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 550 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 551 llvm::divideCeil(bitWidth, CHAR_BIT), result); 552 } else { 553 std::copy_n(rawData + (bitPos / CHAR_BIT), 554 llvm::divideCeil(bitWidth, CHAR_BIT), 555 const_cast<char *>( 556 reinterpret_cast<const char *>(result.getRawData()))); 557 } 558 return result; 559 } 560 561 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 562 /// the same element count as 'type'. 563 template <typename Values> 564 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 565 return (values.size() == 1) || 566 (type.getNumElements() == static_cast<int64_t>(values.size())); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // DenseElementsAttr Iterators 571 //===----------------------------------------------------------------------===// 572 573 //===----------------------------------------------------------------------===// 574 // AttributeElementIterator 575 576 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 577 DenseElementsAttr attr, size_t index) 578 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 579 Attribute, Attribute, Attribute>( 580 attr.getAsOpaquePointer(), index) {} 581 582 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 583 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 584 Type eltTy = owner.getElementType(); 585 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 586 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 587 if (eltTy.isa<IndexType>()) 588 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 589 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 590 IntElementIterator intIt(owner, index); 591 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 592 return FloatAttr::get(eltTy, *floatIt); 593 } 594 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 595 auto complexEltTy = complexTy.getElementType(); 596 ComplexIntElementIterator complexIntIt(owner, index); 597 if (complexEltTy.isa<IntegerType>()) { 598 auto value = *complexIntIt; 599 auto real = IntegerAttr::get(complexEltTy, value.real()); 600 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 601 return ArrayAttr::get(complexTy.getContext(), 602 ArrayRef<Attribute>{real, imag}); 603 } 604 605 ComplexFloatElementIterator complexFloatIt( 606 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 607 auto value = *complexFloatIt; 608 auto real = FloatAttr::get(complexEltTy, value.real()); 609 auto imag = FloatAttr::get(complexEltTy, value.imag()); 610 return ArrayAttr::get(complexTy.getContext(), 611 ArrayRef<Attribute>{real, imag}); 612 } 613 if (owner.isa<DenseStringElementsAttr>()) { 614 ArrayRef<StringRef> vals = owner.getRawStringData(); 615 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 616 } 617 llvm_unreachable("unexpected element type"); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // BoolElementIterator 622 623 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 624 DenseElementsAttr attr, size_t dataIndex) 625 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 626 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 627 628 bool DenseElementsAttr::BoolElementIterator::operator*() const { 629 return getBit(getData(), getDataIndex()); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // IntElementIterator 634 635 DenseElementsAttr::IntElementIterator::IntElementIterator( 636 DenseElementsAttr attr, size_t dataIndex) 637 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 638 attr.getRawData().data(), attr.isSplat(), dataIndex), 639 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 640 641 APInt DenseElementsAttr::IntElementIterator::operator*() const { 642 return readBits(getData(), 643 getDataIndex() * getDenseElementStorageWidth(bitWidth), 644 bitWidth); 645 } 646 647 //===----------------------------------------------------------------------===// 648 // ComplexIntElementIterator 649 650 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 651 DenseElementsAttr attr, size_t dataIndex) 652 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 653 std::complex<APInt>, std::complex<APInt>, 654 std::complex<APInt>>( 655 attr.getRawData().data(), attr.isSplat(), dataIndex) { 656 auto complexType = attr.getElementType().cast<ComplexType>(); 657 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 658 } 659 660 std::complex<APInt> 661 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 662 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 663 size_t offset = getDataIndex() * storageWidth * 2; 664 return {readBits(getData(), offset, bitWidth), 665 readBits(getData(), offset + storageWidth, bitWidth)}; 666 } 667 668 //===----------------------------------------------------------------------===// 669 // DenseArrayAttr 670 //===----------------------------------------------------------------------===// 671 672 /// Custom storage to ensure proper memory alignment for the allocation of 673 /// DenseArray of any element type. 674 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage { 675 using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType, 676 ::llvm::ArrayRef<char>>; 677 DenseArrayBaseAttrStorage(ShapedType type, 678 DenseArrayBaseAttr::EltType eltType, 679 ::llvm::ArrayRef<char> elements) 680 : AttributeStorage(type), eltType(eltType), elements(elements) {} 681 682 bool operator==(const KeyTy &tblgenKey) const { 683 return (getType() == std::get<0>(tblgenKey)) && 684 (eltType == std::get<1>(tblgenKey)) && 685 (elements == std::get<2>(tblgenKey)); 686 } 687 688 static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { 689 return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), 690 std::get<2>(tblgenKey)); 691 } 692 693 static DenseArrayBaseAttrStorage * 694 construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) { 695 auto type = std::get<0>(tblgenKey); 696 auto eltType = std::get<1>(tblgenKey); 697 auto elements = std::get<2>(tblgenKey); 698 if (!elements.empty()) { 699 char *alloc = static_cast<char *>( 700 allocator.allocate(elements.size(), alignof(uint64_t))); 701 std::uninitialized_copy(elements.begin(), elements.end(), alloc); 702 elements = ArrayRef<char>(alloc, elements.size()); 703 } 704 return new (allocator.allocate<DenseArrayBaseAttrStorage>()) 705 DenseArrayBaseAttrStorage(type, eltType, elements); 706 } 707 708 DenseArrayBaseAttr::EltType eltType; 709 ::llvm::ArrayRef<char> elements; 710 }; 711 712 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const { 713 return getImpl()->eltType; 714 } 715 716 const int8_t * 717 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const { 718 return cast<DenseI8ArrayAttr>().asArrayRef().begin(); 719 } 720 const int16_t * 721 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const { 722 return cast<DenseI16ArrayAttr>().asArrayRef().begin(); 723 } 724 const int32_t * 725 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const { 726 return cast<DenseI32ArrayAttr>().asArrayRef().begin(); 727 } 728 const int64_t * 729 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const { 730 return cast<DenseI64ArrayAttr>().asArrayRef().begin(); 731 } 732 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const { 733 return cast<DenseF32ArrayAttr>().asArrayRef().begin(); 734 } 735 const double * 736 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const { 737 return cast<DenseF64ArrayAttr>().asArrayRef().begin(); 738 } 739 740 void DenseArrayBaseAttr::print(AsmPrinter &printer) const { 741 print(printer.getStream()); 742 } 743 744 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { 745 switch (getElementType()) { 746 case DenseArrayBaseAttr::EltType::I8: 747 this->cast<DenseI8ArrayAttr>().printWithoutBraces(os); 748 return; 749 case DenseArrayBaseAttr::EltType::I16: 750 this->cast<DenseI16ArrayAttr>().printWithoutBraces(os); 751 return; 752 case DenseArrayBaseAttr::EltType::I32: 753 this->cast<DenseI32ArrayAttr>().printWithoutBraces(os); 754 return; 755 case DenseArrayBaseAttr::EltType::I64: 756 this->cast<DenseI64ArrayAttr>().printWithoutBraces(os); 757 return; 758 case DenseArrayBaseAttr::EltType::F32: 759 this->cast<DenseF32ArrayAttr>().printWithoutBraces(os); 760 return; 761 case DenseArrayBaseAttr::EltType::F64: 762 this->cast<DenseF64ArrayAttr>().printWithoutBraces(os); 763 return; 764 } 765 llvm_unreachable("<unknown DenseArrayBaseAttr>"); 766 } 767 768 void DenseArrayBaseAttr::print(raw_ostream &os) const { 769 os << "["; 770 printWithoutBraces(os); 771 os << "]"; 772 } 773 774 template <typename T> 775 void DenseArrayAttr<T>::print(AsmPrinter &printer) const { 776 print(printer.getStream()); 777 } 778 779 template <typename T> 780 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const { 781 ArrayRef<T> values{*this}; 782 llvm::interleaveComma(values, os); 783 } 784 785 /// Specialization for int8_t for forcing printing as number instead of chars. 786 template <> 787 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const { 788 ArrayRef<int8_t> values{*this}; 789 llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); 790 } 791 792 template <typename T> 793 void DenseArrayAttr<T>::print(raw_ostream &os) const { 794 os << "["; 795 printWithoutBraces(os); 796 os << "]"; 797 } 798 799 /// Parse a single element: generic template for int types, specialized for 800 /// floating points below. 801 template <typename T> 802 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { 803 return parser.parseInteger(value); 804 } 805 806 template <> 807 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) { 808 double doubleVal; 809 if (parser.parseFloat(doubleVal)) 810 return failure(); 811 value = doubleVal; 812 return success(); 813 } 814 815 template <> 816 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) { 817 return parser.parseFloat(value); 818 } 819 820 /// Parse a DenseArrayAttr without the braces: `1, 2, 3` 821 template <typename T> 822 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser, 823 Type odsType) { 824 SmallVector<T> data; 825 if (failed(parser.parseCommaSeparatedList([&]() { 826 T value; 827 if (parseDenseArrayAttrElt(parser, value)) 828 return failure(); 829 data.push_back(value); 830 return success(); 831 }))) 832 return {}; 833 return get(parser.getContext(), data); 834 } 835 836 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` 837 template <typename T> 838 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) { 839 if (parser.parseLSquare()) 840 return {}; 841 // Handle empty list case. 842 if (succeeded(parser.parseOptionalRSquare())) 843 return get(parser.getContext(), {}); 844 Attribute result = parseWithoutBraces(parser, odsType); 845 if (parser.parseRSquare()) 846 return {}; 847 return result; 848 } 849 850 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. 851 template <typename T> 852 DenseArrayAttr<T>::operator ArrayRef<T>() const { 853 ArrayRef<char> raw = getImpl()->elements; 854 assert((raw.size() % sizeof(T)) == 0); 855 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), 856 raw.size() / sizeof(T)); 857 } 858 859 namespace { 860 /// Mapping from C++ element type to MLIR DenseArrayAttr internals. 861 template <typename T> 862 struct denseArrayAttrEltTypeBuilder; 863 template <> 864 struct denseArrayAttrEltTypeBuilder<int8_t> { 865 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8; 866 static ShapedType getShapedType(MLIRContext *context, 867 ArrayRef<int64_t> shape) { 868 return VectorType::get(shape, IntegerType::get(context, 8)); 869 } 870 }; 871 template <> 872 struct denseArrayAttrEltTypeBuilder<int16_t> { 873 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16; 874 static ShapedType getShapedType(MLIRContext *context, 875 ArrayRef<int64_t> shape) { 876 return VectorType::get(shape, IntegerType::get(context, 16)); 877 } 878 }; 879 template <> 880 struct denseArrayAttrEltTypeBuilder<int32_t> { 881 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32; 882 static ShapedType getShapedType(MLIRContext *context, 883 ArrayRef<int64_t> shape) { 884 return VectorType::get(shape, IntegerType::get(context, 32)); 885 } 886 }; 887 template <> 888 struct denseArrayAttrEltTypeBuilder<int64_t> { 889 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64; 890 static ShapedType getShapedType(MLIRContext *context, 891 ArrayRef<int64_t> shape) { 892 return VectorType::get(shape, IntegerType::get(context, 64)); 893 } 894 }; 895 template <> 896 struct denseArrayAttrEltTypeBuilder<float> { 897 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; 898 static ShapedType getShapedType(MLIRContext *context, 899 ArrayRef<int64_t> shape) { 900 return VectorType::get(shape, Float32Type::get(context)); 901 } 902 }; 903 template <> 904 struct denseArrayAttrEltTypeBuilder<double> { 905 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64; 906 static ShapedType getShapedType(MLIRContext *context, 907 ArrayRef<int64_t> shape) { 908 return VectorType::get(shape, Float64Type::get(context)); 909 } 910 }; 911 } // namespace 912 913 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. 914 template <typename T> 915 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context, 916 ArrayRef<T> content) { 917 auto size = static_cast<int64_t>(content.size()); 918 auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType( 919 context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{}); 920 auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType; 921 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), 922 content.size() * sizeof(T)); 923 return Base::get(context, shapedType, eltType, rawArray) 924 .template cast<DenseArrayAttr<T>>(); 925 } 926 927 template <typename T> 928 bool DenseArrayAttr<T>::classof(Attribute attr) { 929 return attr.isa<DenseArrayBaseAttr>() && 930 attr.cast<DenseArrayBaseAttr>().getElementType() == 931 denseArrayAttrEltTypeBuilder<T>::eltType; 932 } 933 934 namespace mlir { 935 namespace detail { 936 // Explicit instantiation for all the supported DenseArrayAttr. 937 template class DenseArrayAttr<int8_t>; 938 template class DenseArrayAttr<int16_t>; 939 template class DenseArrayAttr<int32_t>; 940 template class DenseArrayAttr<int64_t>; 941 template class DenseArrayAttr<float>; 942 template class DenseArrayAttr<double>; 943 } // namespace detail 944 } // namespace mlir 945 946 //===----------------------------------------------------------------------===// 947 // DenseElementsAttr 948 //===----------------------------------------------------------------------===// 949 950 /// Method for support type inquiry through isa, cast and dyn_cast. 951 bool DenseElementsAttr::classof(Attribute attr) { 952 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 953 } 954 955 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 956 ArrayRef<Attribute> values) { 957 assert(hasSameElementsOrSplat(type, values)); 958 959 // If the element type is not based on int/float/index, assume it is a string 960 // type. 961 auto eltType = type.getElementType(); 962 if (!type.getElementType().isIntOrIndexOrFloat()) { 963 SmallVector<StringRef, 8> stringValues; 964 stringValues.reserve(values.size()); 965 for (Attribute attr : values) { 966 assert(attr.isa<StringAttr>() && 967 "expected string value for non integer/index/float element"); 968 stringValues.push_back(attr.cast<StringAttr>().getValue()); 969 } 970 return get(type, stringValues); 971 } 972 973 // Otherwise, get the raw storage width to use for the allocation. 974 size_t bitWidth = getDenseElementBitWidth(eltType); 975 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 976 977 // Compress the attribute values into a character buffer. 978 SmallVector<char, 8> data( 979 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); 980 APInt intVal; 981 for (unsigned i = 0, e = values.size(); i < e; ++i) { 982 assert(eltType == values[i].getType() && 983 "expected attribute value to have element type"); 984 if (eltType.isa<FloatType>()) 985 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 986 else if (eltType.isa<IntegerType, IndexType>()) 987 intVal = values[i].cast<IntegerAttr>().getValue(); 988 else 989 llvm_unreachable("unexpected element type"); 990 991 assert(intVal.getBitWidth() == bitWidth && 992 "expected value to have same bitwidth as element type"); 993 writeBits(data.data(), i * storageBitWidth, intVal); 994 } 995 996 // Handle the special encoding of splat of bool. 997 if (values.size() == 1 && values[0].getType().isInteger(1)) 998 data[0] = data[0] ? -1 : 0; 999 1000 return DenseIntOrFPElementsAttr::getRaw(type, data); 1001 } 1002 1003 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1004 ArrayRef<bool> values) { 1005 assert(hasSameElementsOrSplat(type, values)); 1006 assert(type.getElementType().isInteger(1)); 1007 1008 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 1009 1010 if (!values.empty()) { 1011 bool isSplat = true; 1012 bool firstValue = values[0]; 1013 for (int i = 0, e = values.size(); i != e; ++i) { 1014 isSplat &= values[i] == firstValue; 1015 setBit(buff.data(), i, values[i]); 1016 } 1017 1018 // Splat of bool is encoded as a byte with all-ones in it. 1019 if (isSplat) { 1020 buff.resize(1); 1021 buff[0] = values[0] ? -1 : 0; 1022 } 1023 } 1024 1025 return DenseIntOrFPElementsAttr::getRaw(type, buff); 1026 } 1027 1028 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1029 ArrayRef<StringRef> values) { 1030 assert(!type.getElementType().isIntOrFloat()); 1031 return DenseStringElementsAttr::get(type, values); 1032 } 1033 1034 /// Constructs a dense integer elements attribute from an array of APInt 1035 /// values. Each APInt value is expected to have the same bitwidth as the 1036 /// element type of 'type'. 1037 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1038 ArrayRef<APInt> values) { 1039 assert(type.getElementType().isIntOrIndex()); 1040 assert(hasSameElementsOrSplat(type, values)); 1041 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1042 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1043 } 1044 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1045 ArrayRef<std::complex<APInt>> values) { 1046 ComplexType complex = type.getElementType().cast<ComplexType>(); 1047 assert(complex.getElementType().isa<IntegerType>()); 1048 assert(hasSameElementsOrSplat(type, values)); 1049 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1050 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 1051 values.size() * 2); 1052 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); 1053 } 1054 1055 // Constructs a dense float elements attribute from an array of APFloat 1056 // values. Each APFloat value is expected to have the same bitwidth as the 1057 // element type of 'type'. 1058 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1059 ArrayRef<APFloat> values) { 1060 assert(type.getElementType().isa<FloatType>()); 1061 assert(hasSameElementsOrSplat(type, values)); 1062 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1063 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1064 } 1065 DenseElementsAttr 1066 DenseElementsAttr::get(ShapedType type, 1067 ArrayRef<std::complex<APFloat>> values) { 1068 ComplexType complex = type.getElementType().cast<ComplexType>(); 1069 assert(complex.getElementType().isa<FloatType>()); 1070 assert(hasSameElementsOrSplat(type, values)); 1071 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 1072 values.size() * 2); 1073 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1074 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); 1075 } 1076 1077 /// Construct a dense elements attribute from a raw buffer representing the 1078 /// data for this attribute. Users should generally not use this methods as 1079 /// the expected buffer format may not be a form the user expects. 1080 DenseElementsAttr 1081 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { 1082 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); 1083 } 1084 1085 /// Returns true if the given buffer is a valid raw buffer for the given type. 1086 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 1087 ArrayRef<char> rawBuffer, 1088 bool &detectedSplat) { 1089 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 1090 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 1091 int64_t numElements = type.getNumElements(); 1092 1093 // The initializer is always a splat if the result type has a single element. 1094 detectedSplat = numElements == 1; 1095 1096 // Storage width of 1 is special as it is packed by the bit. 1097 if (storageWidth == 1) { 1098 // Check for a splat, or a buffer equal to the number of elements which 1099 // consists of either all 0's or all 1's. 1100 if (rawBuffer.size() == 1) { 1101 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 1102 if (rawByte == 0 || rawByte == 0xff) { 1103 detectedSplat = true; 1104 return true; 1105 } 1106 } 1107 1108 // This is a valid non-splat buffer if it has the right size. 1109 return rawBufferWidth == llvm::alignTo<8>(numElements); 1110 } 1111 1112 // All other types are 8-bit aligned, so we can just check the buffer width 1113 // to know if only a single initializer element was passed in. 1114 if (rawBufferWidth == storageWidth) { 1115 detectedSplat = true; 1116 return true; 1117 } 1118 1119 // The raw buffer is valid if it has the right size. 1120 return rawBufferWidth == storageWidth * numElements; 1121 } 1122 1123 /// Check the information for a C++ data type, check if this type is valid for 1124 /// the current attribute. This method is used to verify specific type 1125 /// invariants that the templatized 'getValues' method cannot. 1126 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 1127 bool isSigned) { 1128 // Make sure that the data element size is the same as the type element width. 1129 if (getDenseElementBitWidth(type) != 1130 static_cast<size_t>(dataEltSize * CHAR_BIT)) 1131 return false; 1132 1133 // Check that the element type is either float or integer or index. 1134 if (!isInt) 1135 return type.isa<FloatType>(); 1136 if (type.isIndex()) 1137 return true; 1138 1139 auto intType = type.dyn_cast<IntegerType>(); 1140 if (!intType) 1141 return false; 1142 1143 // Make sure signedness semantics is consistent. 1144 if (intType.isSignless()) 1145 return true; 1146 return intType.isSigned() ? isSigned : !isSigned; 1147 } 1148 1149 /// Defaults down the subclass implementation. 1150 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 1151 ArrayRef<char> data, 1152 int64_t dataEltSize, 1153 bool isInt, bool isSigned) { 1154 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 1155 isSigned); 1156 } 1157 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 1158 ArrayRef<char> data, 1159 int64_t dataEltSize, 1160 bool isInt, 1161 bool isSigned) { 1162 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 1163 isInt, isSigned); 1164 } 1165 1166 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 1167 bool isSigned) const { 1168 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 1169 } 1170 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 1171 bool isSigned) const { 1172 return ::isValidIntOrFloat( 1173 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 1174 isInt, isSigned); 1175 } 1176 1177 /// Returns true if this attribute corresponds to a splat, i.e. if all element 1178 /// values are the same. 1179 bool DenseElementsAttr::isSplat() const { 1180 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 1181 } 1182 1183 /// Return if the given complex type has an integer element type. 1184 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { 1185 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 1186 } 1187 1188 auto DenseElementsAttr::getComplexIntValues() const 1189 -> iterator_range_impl<ComplexIntElementIterator> { 1190 assert(isComplexOfIntType(getElementType()) && 1191 "expected complex integral type"); 1192 return {getType(), ComplexIntElementIterator(*this, 0), 1193 ComplexIntElementIterator(*this, getNumElements())}; 1194 } 1195 auto DenseElementsAttr::complex_value_begin() const 1196 -> ComplexIntElementIterator { 1197 assert(isComplexOfIntType(getElementType()) && 1198 "expected complex integral type"); 1199 return ComplexIntElementIterator(*this, 0); 1200 } 1201 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 1202 assert(isComplexOfIntType(getElementType()) && 1203 "expected complex integral type"); 1204 return ComplexIntElementIterator(*this, getNumElements()); 1205 } 1206 1207 /// Return the held element values as a range of APFloat. The element type of 1208 /// this attribute must be of float type. 1209 auto DenseElementsAttr::getFloatValues() const 1210 -> iterator_range_impl<FloatElementIterator> { 1211 auto elementType = getElementType().cast<FloatType>(); 1212 const auto &elementSemantics = elementType.getFloatSemantics(); 1213 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 1214 FloatElementIterator(elementSemantics, raw_int_end())}; 1215 } 1216 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1217 auto elementType = getElementType().cast<FloatType>(); 1218 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 1219 } 1220 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1221 auto elementType = getElementType().cast<FloatType>(); 1222 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 1223 } 1224 1225 auto DenseElementsAttr::getComplexFloatValues() const 1226 -> iterator_range_impl<ComplexFloatElementIterator> { 1227 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1228 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1229 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1230 return {getType(), 1231 {semantics, {*this, 0}}, 1232 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1233 } 1234 auto DenseElementsAttr::complex_float_value_begin() const 1235 -> ComplexFloatElementIterator { 1236 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1237 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1238 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 1239 } 1240 auto DenseElementsAttr::complex_float_value_end() const 1241 -> ComplexFloatElementIterator { 1242 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1243 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1244 return {eltTy.cast<FloatType>().getFloatSemantics(), 1245 {*this, static_cast<size_t>(getNumElements())}}; 1246 } 1247 1248 /// Return the raw storage data held by this attribute. 1249 ArrayRef<char> DenseElementsAttr::getRawData() const { 1250 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1251 } 1252 1253 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1254 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1255 } 1256 1257 /// Return a new DenseElementsAttr that has the same data as the current 1258 /// attribute, but has been reshaped to 'newType'. The new type must have the 1259 /// same total number of elements as well as element type. 1260 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1261 ShapedType curType = getType(); 1262 if (curType == newType) 1263 return *this; 1264 1265 assert(newType.getElementType() == curType.getElementType() && 1266 "expected the same element type"); 1267 assert(newType.getNumElements() == curType.getNumElements() && 1268 "expected the same number of elements"); 1269 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1270 } 1271 1272 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { 1273 assert(isSplat() && "expected a splat type"); 1274 1275 ShapedType curType = getType(); 1276 if (curType == newType) 1277 return *this; 1278 1279 assert(newType.getElementType() == curType.getElementType() && 1280 "expected the same element type"); 1281 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1282 } 1283 1284 /// Return a new DenseElementsAttr that has the same data as the current 1285 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1286 /// type must have the same shape and element types of the same bitwidth as the 1287 /// current type. 1288 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1289 ShapedType curType = getType(); 1290 Type curElType = curType.getElementType(); 1291 if (curElType == newElType) 1292 return *this; 1293 1294 assert(getDenseElementBitWidth(newElType) == 1295 getDenseElementBitWidth(curElType) && 1296 "expected element types with the same bitwidth"); 1297 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1298 getRawData()); 1299 } 1300 1301 DenseElementsAttr 1302 DenseElementsAttr::mapValues(Type newElementType, 1303 function_ref<APInt(const APInt &)> mapping) const { 1304 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1305 } 1306 1307 DenseElementsAttr DenseElementsAttr::mapValues( 1308 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1309 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1310 } 1311 1312 ShapedType DenseElementsAttr::getType() const { 1313 return Attribute::getType().cast<ShapedType>(); 1314 } 1315 1316 Type DenseElementsAttr::getElementType() const { 1317 return getType().getElementType(); 1318 } 1319 1320 int64_t DenseElementsAttr::getNumElements() const { 1321 return getType().getNumElements(); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // DenseIntOrFPElementsAttr 1326 //===----------------------------------------------------------------------===// 1327 1328 /// Utility method to write a range of APInt values to a buffer. 1329 template <typename APRangeT> 1330 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1331 APRangeT &&values) { 1332 size_t numValues = llvm::size(values); 1333 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); 1334 size_t offset = 0; 1335 for (auto it = values.begin(), e = values.end(); it != e; 1336 ++it, offset += storageWidth) { 1337 assert((*it).getBitWidth() <= storageWidth); 1338 writeBits(data.data(), offset, *it); 1339 } 1340 1341 // Handle the special encoding of splat of a boolean. 1342 if (numValues == 1 && (*values.begin()).getBitWidth() == 1) 1343 data[0] = data[0] ? -1 : 0; 1344 } 1345 1346 /// Constructs a dense elements attribute from an array of raw APFloat values. 1347 /// Each APFloat value is expected to have the same bitwidth as the element 1348 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1349 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1350 size_t storageWidth, 1351 ArrayRef<APFloat> values) { 1352 std::vector<char> data; 1353 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1354 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1355 return DenseIntOrFPElementsAttr::getRaw(type, data); 1356 } 1357 1358 /// Constructs a dense elements attribute from an array of raw APInt values. 1359 /// Each APInt value is expected to have the same bitwidth as the element type 1360 /// of 'type'. 1361 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1362 size_t storageWidth, 1363 ArrayRef<APInt> values) { 1364 std::vector<char> data; 1365 writeAPIntsToBuffer(storageWidth, data, values); 1366 return DenseIntOrFPElementsAttr::getRaw(type, data); 1367 } 1368 1369 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1370 ArrayRef<char> data) { 1371 assert((type.isa<RankedTensorType, VectorType>()) && 1372 "type must be ranked tensor or vector"); 1373 assert(type.hasStaticShape() && "type must have static shape"); 1374 bool isSplat = false; 1375 bool isValid = isValidRawBuffer(type, data, isSplat); 1376 assert(isValid); 1377 (void)isValid; 1378 return Base::get(type.getContext(), type, data, isSplat); 1379 } 1380 1381 /// Overload of the raw 'get' method that asserts that the given type is of 1382 /// complex type. This method is used to verify type invariants that the 1383 /// templatized 'get' method cannot. 1384 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1385 ArrayRef<char> data, 1386 int64_t dataEltSize, 1387 bool isInt, 1388 bool isSigned) { 1389 assert(::isValidIntOrFloat( 1390 type.getElementType().cast<ComplexType>().getElementType(), 1391 dataEltSize / 2, isInt, isSigned)); 1392 1393 int64_t numElements = data.size() / dataEltSize; 1394 (void)numElements; 1395 assert(numElements == 1 || numElements == type.getNumElements()); 1396 return getRaw(type, data); 1397 } 1398 1399 /// Overload of the 'getRaw' method that asserts that the given type is of 1400 /// integer type. This method is used to verify type invariants that the 1401 /// templatized 'get' method cannot. 1402 DenseElementsAttr 1403 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1404 int64_t dataEltSize, bool isInt, 1405 bool isSigned) { 1406 assert( 1407 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1408 1409 int64_t numElements = data.size() / dataEltSize; 1410 assert(numElements == 1 || numElements == type.getNumElements()); 1411 (void)numElements; 1412 return getRaw(type, data); 1413 } 1414 1415 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1416 const char *inRawData, char *outRawData, size_t elementBitWidth, 1417 size_t numElements) { 1418 using llvm::support::ulittle16_t; 1419 using llvm::support::ulittle32_t; 1420 using llvm::support::ulittle64_t; 1421 1422 assert(llvm::support::endian::system_endianness() == // NOLINT 1423 llvm::support::endianness::big); // NOLINT 1424 // NOLINT to avoid warning message about replacing by static_assert() 1425 1426 // Following std::copy_n always converts endianness on BE machine. 1427 switch (elementBitWidth) { 1428 case 16: { 1429 const ulittle16_t *inRawDataPos = 1430 reinterpret_cast<const ulittle16_t *>(inRawData); 1431 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1432 std::copy_n(inRawDataPos, numElements, outDataPos); 1433 break; 1434 } 1435 case 32: { 1436 const ulittle32_t *inRawDataPos = 1437 reinterpret_cast<const ulittle32_t *>(inRawData); 1438 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1439 std::copy_n(inRawDataPos, numElements, outDataPos); 1440 break; 1441 } 1442 case 64: { 1443 const ulittle64_t *inRawDataPos = 1444 reinterpret_cast<const ulittle64_t *>(inRawData); 1445 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1446 std::copy_n(inRawDataPos, numElements, outDataPos); 1447 break; 1448 } 1449 default: { 1450 size_t nBytes = elementBitWidth / CHAR_BIT; 1451 for (size_t i = 0; i < nBytes; i++) 1452 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1453 break; 1454 } 1455 } 1456 } 1457 1458 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1459 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1460 ShapedType type) { 1461 size_t numElements = type.getNumElements(); 1462 Type elementType = type.getElementType(); 1463 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1464 elementType = complexTy.getElementType(); 1465 numElements = numElements * 2; 1466 } 1467 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1468 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1469 inRawData.size() <= outRawData.size()); 1470 if (elementBitWidth <= CHAR_BIT) 1471 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); 1472 else 1473 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1474 elementBitWidth, numElements); 1475 } 1476 1477 //===----------------------------------------------------------------------===// 1478 // DenseFPElementsAttr 1479 //===----------------------------------------------------------------------===// 1480 1481 template <typename Fn, typename Attr> 1482 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1483 Type newElementType, 1484 llvm::SmallVectorImpl<char> &data) { 1485 size_t bitWidth = getDenseElementBitWidth(newElementType); 1486 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1487 1488 ShapedType newArrayType; 1489 if (inType.isa<RankedTensorType>()) 1490 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1491 else if (inType.isa<UnrankedTensorType>()) 1492 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1493 else if (auto vType = inType.dyn_cast<VectorType>()) 1494 newArrayType = VectorType::get(vType.getShape(), newElementType, 1495 vType.getNumScalableDims()); 1496 else 1497 assert(newArrayType && "Unhandled tensor type"); 1498 1499 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1500 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); 1501 1502 // Functor used to process a single element value of the attribute. 1503 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1504 auto newInt = mapping(value); 1505 assert(newInt.getBitWidth() == bitWidth); 1506 writeBits(data.data(), index * storageBitWidth, newInt); 1507 }; 1508 1509 // Check for the splat case. 1510 if (attr.isSplat()) { 1511 processElt(*attr.begin(), /*index=*/0); 1512 return newArrayType; 1513 } 1514 1515 // Otherwise, process all of the element values. 1516 uint64_t elementIdx = 0; 1517 for (auto value : attr) 1518 processElt(value, elementIdx++); 1519 return newArrayType; 1520 } 1521 1522 DenseElementsAttr DenseFPElementsAttr::mapValues( 1523 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1524 llvm::SmallVector<char, 8> elementData; 1525 auto newArrayType = 1526 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1527 1528 return getRaw(newArrayType, elementData); 1529 } 1530 1531 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1532 bool DenseFPElementsAttr::classof(Attribute attr) { 1533 return attr.isa<DenseElementsAttr>() && 1534 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1535 } 1536 1537 //===----------------------------------------------------------------------===// 1538 // DenseIntElementsAttr 1539 //===----------------------------------------------------------------------===// 1540 1541 DenseElementsAttr DenseIntElementsAttr::mapValues( 1542 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1543 llvm::SmallVector<char, 8> elementData; 1544 auto newArrayType = 1545 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1546 return getRaw(newArrayType, elementData); 1547 } 1548 1549 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1550 bool DenseIntElementsAttr::classof(Attribute attr) { 1551 return attr.isa<DenseElementsAttr>() && 1552 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1553 } 1554 1555 //===----------------------------------------------------------------------===// 1556 // OpaqueElementsAttr 1557 //===----------------------------------------------------------------------===// 1558 1559 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1560 Dialect *dialect = getContext()->getLoadedDialect(getDialect()); 1561 if (!dialect) 1562 return true; 1563 auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect); 1564 if (!interface) 1565 return true; 1566 return failed(interface->decode(*this, result)); 1567 } 1568 1569 LogicalResult 1570 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1571 StringAttr dialect, StringRef value, 1572 ShapedType type) { 1573 if (!Dialect::isValidNamespace(dialect.strref())) 1574 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1575 return success(); 1576 } 1577 1578 //===----------------------------------------------------------------------===// 1579 // SparseElementsAttr 1580 //===----------------------------------------------------------------------===// 1581 1582 /// Get a zero APFloat for the given sparse attribute. 1583 APFloat SparseElementsAttr::getZeroAPFloat() const { 1584 auto eltType = getElementType().cast<FloatType>(); 1585 return APFloat(eltType.getFloatSemantics()); 1586 } 1587 1588 /// Get a zero APInt for the given sparse attribute. 1589 APInt SparseElementsAttr::getZeroAPInt() const { 1590 auto eltType = getElementType().cast<IntegerType>(); 1591 return APInt::getZero(eltType.getWidth()); 1592 } 1593 1594 /// Get a zero attribute for the given attribute type. 1595 Attribute SparseElementsAttr::getZeroAttr() const { 1596 auto eltType = getElementType(); 1597 1598 // Handle floating point elements. 1599 if (eltType.isa<FloatType>()) 1600 return FloatAttr::get(eltType, 0); 1601 1602 // Handle complex elements. 1603 if (auto complexTy = eltType.dyn_cast<ComplexType>()) { 1604 auto eltType = complexTy.getElementType(); 1605 Attribute zero; 1606 if (eltType.isa<FloatType>()) 1607 zero = FloatAttr::get(eltType, 0); 1608 else // must be integer 1609 zero = IntegerAttr::get(eltType, 0); 1610 return ArrayAttr::get(complexTy.getContext(), 1611 ArrayRef<Attribute>{zero, zero}); 1612 } 1613 1614 // Handle string type. 1615 if (getValues().isa<DenseStringElementsAttr>()) 1616 return StringAttr::get("", eltType); 1617 1618 // Otherwise, this is an integer. 1619 return IntegerAttr::get(eltType, 0); 1620 } 1621 1622 /// Flatten, and return, all of the sparse indices in this attribute in 1623 /// row-major order. 1624 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1625 std::vector<ptrdiff_t> flatSparseIndices; 1626 1627 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1628 // as a 1-D index array. 1629 auto sparseIndices = getIndices(); 1630 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1631 if (sparseIndices.isSplat()) { 1632 SmallVector<uint64_t, 8> indices(getType().getRank(), 1633 *sparseIndexValues.begin()); 1634 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1635 return flatSparseIndices; 1636 } 1637 1638 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1639 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1640 size_t rank = getType().getRank(); 1641 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1642 flatSparseIndices.push_back(getFlattenedIndex( 1643 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1644 return flatSparseIndices; 1645 } 1646 1647 LogicalResult 1648 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1649 ShapedType type, DenseIntElementsAttr sparseIndices, 1650 DenseElementsAttr values) { 1651 ShapedType valuesType = values.getType(); 1652 if (valuesType.getRank() != 1) 1653 return emitError() << "expected 1-d tensor for sparse element values"; 1654 1655 // Verify the indices and values shape. 1656 ShapedType indicesType = sparseIndices.getType(); 1657 auto emitShapeError = [&]() { 1658 return emitError() << "expected shape ([" << type.getShape() 1659 << "]); inferred shape of indices literal ([" 1660 << indicesType.getShape() 1661 << "]); inferred shape of values literal ([" 1662 << valuesType.getShape() << "])"; 1663 }; 1664 // Verify indices shape. 1665 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1666 if (indicesRank == 2) { 1667 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1668 return emitShapeError(); 1669 } else if (indicesRank != 1 || rank != 1) { 1670 return emitShapeError(); 1671 } 1672 // Verify the values shape. 1673 int64_t numSparseIndices = indicesType.getDimSize(0); 1674 if (numSparseIndices != valuesType.getDimSize(0)) 1675 return emitShapeError(); 1676 1677 // Verify that the sparse indices are within the value shape. 1678 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1679 return emitError() 1680 << "sparse index #" << indexNum 1681 << " is not contained within the value shape, with index=[" << index 1682 << "], and type=" << type; 1683 }; 1684 1685 // Handle the case where the index values are a splat. 1686 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1687 if (sparseIndices.isSplat()) { 1688 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1689 if (!ElementsAttr::isValidIndex(type, indices)) 1690 return emitIndexError(0, indices); 1691 return success(); 1692 } 1693 1694 // Otherwise, reinterpret each index as an ArrayRef. 1695 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1696 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1697 rank); 1698 if (!ElementsAttr::isValidIndex(type, index)) 1699 return emitIndexError(i, index); 1700 } 1701 1702 return success(); 1703 } 1704 1705 //===----------------------------------------------------------------------===// 1706 // TypeAttr 1707 //===----------------------------------------------------------------------===// 1708 1709 void TypeAttr::walkImmediateSubElements( 1710 function_ref<void(Attribute)> walkAttrsFn, 1711 function_ref<void(Type)> walkTypesFn) const { 1712 walkTypesFn(getValue()); 1713 } 1714