1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "mlir/IR/Types.h" 19 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 20 #include "llvm/ADT/APSInt.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/Support/Endian.h" 23 24 using namespace mlir; 25 using namespace mlir::detail; 26 27 //===----------------------------------------------------------------------===// 28 /// Tablegen Attribute Definitions 29 //===----------------------------------------------------------------------===// 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/IR/BuiltinAttributes.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 // BuiltinDialect 36 //===----------------------------------------------------------------------===// 37 38 void BuiltinDialect::registerAttributes() { 39 addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr, 40 DenseIntOrFPElementsAttr, DenseStringElementsAttr, 41 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr, 42 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr, 43 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // ArrayAttr 48 //===----------------------------------------------------------------------===// 49 50 void ArrayAttr::walkImmediateSubElements( 51 function_ref<void(Attribute)> walkAttrsFn, 52 function_ref<void(Type)> walkTypesFn) const { 53 for (Attribute attr : getValue()) 54 walkAttrsFn(attr); 55 } 56 57 Attribute 58 ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 59 ArrayRef<Type> replTypes) const { 60 return get(getContext(), replAttrs); 61 } 62 63 //===----------------------------------------------------------------------===// 64 // DictionaryAttr 65 //===----------------------------------------------------------------------===// 66 67 /// Helper function that does either an in place sort or sorts from source array 68 /// into destination. If inPlace then storage is both the source and the 69 /// destination, else value is the source and storage destination. Returns 70 /// whether source was sorted. 71 template <bool inPlace> 72 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 73 SmallVectorImpl<NamedAttribute> &storage) { 74 // Specialize for the common case. 75 switch (value.size()) { 76 case 0: 77 // Zero already sorted. 78 if (!inPlace) 79 storage.clear(); 80 break; 81 case 1: 82 // One already sorted but may need to be copied. 83 if (!inPlace) 84 storage.assign({value[0]}); 85 break; 86 case 2: { 87 bool isSorted = value[0] < value[1]; 88 if (inPlace) { 89 if (!isSorted) 90 std::swap(storage[0], storage[1]); 91 } else if (isSorted) { 92 storage.assign({value[0], value[1]}); 93 } else { 94 storage.assign({value[1], value[0]}); 95 } 96 return !isSorted; 97 } 98 default: 99 if (!inPlace) 100 storage.assign(value.begin(), value.end()); 101 // Check to see they are sorted already. 102 bool isSorted = llvm::is_sorted(value); 103 // If not, do a general sort. 104 if (!isSorted) 105 llvm::array_pod_sort(storage.begin(), storage.end()); 106 return !isSorted; 107 } 108 return false; 109 } 110 111 /// Returns an entry with a duplicate name from the given sorted array of named 112 /// attributes. Returns llvm::None if all elements have unique names. 113 static Optional<NamedAttribute> 114 findDuplicateElement(ArrayRef<NamedAttribute> value) { 115 const Optional<NamedAttribute> none{llvm::None}; 116 if (value.size() < 2) 117 return none; 118 119 if (value.size() == 2) 120 return value[0].getName() == value[1].getName() ? value[0] : none; 121 122 const auto *it = std::adjacent_find(value.begin(), value.end(), 123 [](NamedAttribute l, NamedAttribute r) { 124 return l.getName() == r.getName(); 125 }); 126 return it != value.end() ? *it : none; 127 } 128 129 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 130 SmallVectorImpl<NamedAttribute> &storage) { 131 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 132 assert(!findDuplicateElement(storage) && 133 "DictionaryAttr element names must be unique"); 134 return isSorted; 135 } 136 137 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 138 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 139 assert(!findDuplicateElement(array) && 140 "DictionaryAttr element names must be unique"); 141 return isSorted; 142 } 143 144 Optional<NamedAttribute> 145 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 146 bool isSorted) { 147 if (!isSorted) 148 dictionaryAttrSort</*inPlace=*/true>(array, array); 149 return findDuplicateElement(array); 150 } 151 152 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 153 ArrayRef<NamedAttribute> value) { 154 if (value.empty()) 155 return DictionaryAttr::getEmpty(context); 156 157 // We need to sort the element list to canonicalize it. 158 SmallVector<NamedAttribute, 8> storage; 159 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 160 value = storage; 161 assert(!findDuplicateElement(value) && 162 "DictionaryAttr element names must be unique"); 163 return Base::get(context, value); 164 } 165 /// Construct a dictionary with an array of values that is known to already be 166 /// sorted by name and uniqued. 167 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 168 ArrayRef<NamedAttribute> value) { 169 if (value.empty()) 170 return DictionaryAttr::getEmpty(context); 171 // Ensure that the attribute elements are unique and sorted. 172 assert(llvm::is_sorted( 173 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && 174 "expected attribute values to be sorted"); 175 assert(!findDuplicateElement(value) && 176 "DictionaryAttr element names must be unique"); 177 return Base::get(context, value); 178 } 179 180 /// Return the specified attribute if present, null otherwise. 181 Attribute DictionaryAttr::get(StringRef name) const { 182 auto it = impl::findAttrSorted(begin(), end(), name); 183 return it.second ? it.first->getValue() : Attribute(); 184 } 185 Attribute DictionaryAttr::get(StringAttr name) const { 186 auto it = impl::findAttrSorted(begin(), end(), name); 187 return it.second ? it.first->getValue() : Attribute(); 188 } 189 190 /// Return the specified named attribute if present, None otherwise. 191 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 192 auto it = impl::findAttrSorted(begin(), end(), name); 193 return it.second ? *it.first : Optional<NamedAttribute>(); 194 } 195 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { 196 auto it = impl::findAttrSorted(begin(), end(), name); 197 return it.second ? *it.first : Optional<NamedAttribute>(); 198 } 199 200 /// Return whether the specified attribute is present. 201 bool DictionaryAttr::contains(StringRef name) const { 202 return impl::findAttrSorted(begin(), end(), name).second; 203 } 204 bool DictionaryAttr::contains(StringAttr name) const { 205 return impl::findAttrSorted(begin(), end(), name).second; 206 } 207 208 DictionaryAttr::iterator DictionaryAttr::begin() const { 209 return getValue().begin(); 210 } 211 DictionaryAttr::iterator DictionaryAttr::end() const { 212 return getValue().end(); 213 } 214 size_t DictionaryAttr::size() const { return getValue().size(); } 215 216 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 217 return Base::get(context, ArrayRef<NamedAttribute>()); 218 } 219 220 void DictionaryAttr::walkImmediateSubElements( 221 function_ref<void(Attribute)> walkAttrsFn, 222 function_ref<void(Type)> walkTypesFn) const { 223 for (const NamedAttribute &attr : getValue()) 224 walkAttrsFn(attr.getValue()); 225 } 226 227 Attribute 228 DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 229 ArrayRef<Type> replTypes) const { 230 std::vector<NamedAttribute> vec = getValue().vec(); 231 for (auto &it : llvm::enumerate(replAttrs)) 232 vec[it.index()].setValue(it.value()); 233 234 // The above only modifies the mapped value, but not the key, and therefore 235 // not the order of the elements. It remains sorted 236 return getWithSorted(getContext(), vec); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // StringAttr 241 //===----------------------------------------------------------------------===// 242 243 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 244 return Base::get(context, "", NoneType::get(context)); 245 } 246 247 /// Twine support for StringAttr. 248 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 249 // Fast-path empty twine. 250 if (twine.isTriviallyEmpty()) 251 return get(context); 252 SmallVector<char, 32> tempStr; 253 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 254 } 255 256 /// Twine support for StringAttr. 257 StringAttr StringAttr::get(const Twine &twine, Type type) { 258 SmallVector<char, 32> tempStr; 259 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 260 } 261 262 StringRef StringAttr::getValue() const { return getImpl()->value; } 263 264 Dialect *StringAttr::getReferencedDialect() const { 265 return getImpl()->referencedDialect; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // FloatAttr 270 //===----------------------------------------------------------------------===// 271 272 double FloatAttr::getValueAsDouble() const { 273 return getValueAsDouble(getValue()); 274 } 275 double FloatAttr::getValueAsDouble(APFloat value) { 276 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 277 bool losesInfo = false; 278 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 279 &losesInfo); 280 } 281 return value.convertToDouble(); 282 } 283 284 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 285 Type type, APFloat value) { 286 // Verify that the type is correct. 287 if (!type.isa<FloatType>()) 288 return emitError() << "expected floating point type"; 289 290 // Verify that the type semantics match that of the value. 291 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 292 return emitError() 293 << "FloatAttr type doesn't match the type implied by its value"; 294 } 295 return success(); 296 } 297 298 //===----------------------------------------------------------------------===// 299 // SymbolRefAttr 300 //===----------------------------------------------------------------------===// 301 302 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 303 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 304 return get(StringAttr::get(ctx, value), nestedRefs); 305 } 306 307 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 308 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 309 } 310 311 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 312 return get(value, {}).cast<FlatSymbolRefAttr>(); 313 } 314 315 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 316 auto symName = 317 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 318 assert(symName && "value does not have a valid symbol name"); 319 return SymbolRefAttr::get(symName); 320 } 321 322 StringAttr SymbolRefAttr::getLeafReference() const { 323 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 324 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); 325 } 326 327 void SymbolRefAttr::walkImmediateSubElements( 328 function_ref<void(Attribute)> walkAttrsFn, 329 function_ref<void(Type)> walkTypesFn) const { 330 walkAttrsFn(getRootReference()); 331 for (FlatSymbolRefAttr ref : getNestedReferences()) 332 walkAttrsFn(ref); 333 } 334 335 Attribute 336 SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 337 ArrayRef<Type> replTypes) const { 338 ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front(); 339 ArrayRef<FlatSymbolRefAttr> nestedRefs( 340 static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()), 341 rawNestedRefs.size()); 342 return get(replAttrs[0].cast<StringAttr>(), nestedRefs); 343 } 344 345 //===----------------------------------------------------------------------===// 346 // IntegerAttr 347 //===----------------------------------------------------------------------===// 348 349 int64_t IntegerAttr::getInt() const { 350 assert((getType().isIndex() || getType().isSignlessInteger()) && 351 "must be signless integer"); 352 return getValue().getSExtValue(); 353 } 354 355 int64_t IntegerAttr::getSInt() const { 356 assert(getType().isSignedInteger() && "must be signed integer"); 357 return getValue().getSExtValue(); 358 } 359 360 uint64_t IntegerAttr::getUInt() const { 361 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 362 return getValue().getZExtValue(); 363 } 364 365 /// Return the value as an APSInt which carries the signed from the type of 366 /// the attribute. This traps on signless integers types! 367 APSInt IntegerAttr::getAPSInt() const { 368 assert(!getType().isSignlessInteger() && 369 "Signless integers don't carry a sign for APSInt"); 370 return APSInt(getValue(), getType().isUnsignedInteger()); 371 } 372 373 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 374 Type type, APInt value) { 375 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 376 if (integerType.getWidth() != value.getBitWidth()) 377 return emitError() << "integer type bit width (" << integerType.getWidth() 378 << ") doesn't match value bit width (" 379 << value.getBitWidth() << ")"; 380 return success(); 381 } 382 if (type.isa<IndexType>()) 383 return success(); 384 return emitError() << "expected integer or index type"; 385 } 386 387 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 388 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 389 return attr.cast<BoolAttr>(); 390 } 391 392 //===----------------------------------------------------------------------===// 393 // BoolAttr 394 //===----------------------------------------------------------------------===// 395 396 bool BoolAttr::getValue() const { 397 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 398 return storage->value.getBoolValue(); 399 } 400 401 bool BoolAttr::classof(Attribute attr) { 402 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 403 return intAttr && intAttr.getType().isSignlessInteger(1); 404 } 405 406 //===----------------------------------------------------------------------===// 407 // OpaqueAttr 408 //===----------------------------------------------------------------------===// 409 410 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 411 StringAttr dialect, StringRef attrData, 412 Type type) { 413 if (!Dialect::isValidNamespace(dialect.strref())) 414 return emitError() << "invalid dialect namespace '" << dialect << "'"; 415 416 // Check that the dialect is actually registered. 417 MLIRContext *context = dialect.getContext(); 418 if (!context->allowsUnregisteredDialects() && 419 !context->getLoadedDialect(dialect.strref())) { 420 return emitError() 421 << "#" << dialect << "<\"" << attrData << "\"> : " << type 422 << " attribute created with unregistered dialect. If this is " 423 "intended, please call allowUnregisteredDialects() on the " 424 "MLIRContext, or use -allow-unregistered-dialect with " 425 "the MLIR opt tool used"; 426 } 427 428 return success(); 429 } 430 431 //===----------------------------------------------------------------------===// 432 // DenseElementsAttr Utilities 433 //===----------------------------------------------------------------------===// 434 435 /// Get the bitwidth of a dense element type within the buffer. 436 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 437 static size_t getDenseElementStorageWidth(size_t origWidth) { 438 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 439 } 440 static size_t getDenseElementStorageWidth(Type elementType) { 441 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 442 } 443 444 /// Set a bit to a specific value. 445 static void setBit(char *rawData, size_t bitPos, bool value) { 446 if (value) 447 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 448 else 449 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 450 } 451 452 /// Return the value of the specified bit. 453 static bool getBit(const char *rawData, size_t bitPos) { 454 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 455 } 456 457 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 458 /// BE format. 459 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 460 char *result) { 461 assert(llvm::support::endian::system_endianness() == // NOLINT 462 llvm::support::endianness::big); // NOLINT 463 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 464 465 // Copy the words filled with data. 466 // For example, when `value` has 2 words, the first word is filled with data. 467 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 468 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 469 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 470 numFilledWords, result); 471 // Convert last word of APInt to LE format and store it in char 472 // array(`valueLE`). 473 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 474 size_t lastWordPos = numFilledWords; 475 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 476 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 477 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 478 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 479 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 480 // and store it in `result`. 481 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 482 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 483 valueLE.begin(), result + lastWordPos, 484 (numBytes - lastWordPos) * CHAR_BIT, 1); 485 } 486 487 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 488 /// format. 489 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 490 APInt &result) { 491 assert(llvm::support::endian::system_endianness() == // NOLINT 492 llvm::support::endianness::big); // NOLINT 493 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 494 495 // Copy the data that fills the word of `result` from `inArray`. 496 // For example, when `result` has 2 words, the first word will be filled with 497 // data. So, the first 8 bytes are copied from `inArray` here. 498 // `inArray` (10 bytes, BE): |abcdefgh|ij| 499 // ==> `result` (2 words, BE): |abcdefgh|--------| 500 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 501 std::copy_n( 502 inArray, numFilledWords, 503 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 504 505 // Convert array data which will be last word of `result` to LE format, and 506 // store it in char array(`inArrayLE`). 507 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 508 size_t lastWordPos = numFilledWords; 509 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 510 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 511 inArray + lastWordPos, inArrayLE.begin(), 512 (numBytes - lastWordPos) * CHAR_BIT, 1); 513 514 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 515 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 516 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 517 inArrayLE.begin(), 518 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 519 lastWordPos, 520 APInt::APINT_BITS_PER_WORD, 1); 521 } 522 523 /// Writes value to the bit position `bitPos` in array `rawData`. 524 static void writeBits(char *rawData, size_t bitPos, APInt value) { 525 size_t bitWidth = value.getBitWidth(); 526 527 // If the bitwidth is 1 we just toggle the specific bit. 528 if (bitWidth == 1) 529 return setBit(rawData, bitPos, value.isOneValue()); 530 531 // Otherwise, the bit position is guaranteed to be byte aligned. 532 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 533 if (llvm::support::endian::system_endianness() == 534 llvm::support::endianness::big) { 535 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 536 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 537 // work correctly in BE format. 538 // ex. `value` (2 words including 10 bytes) 539 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 540 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 541 rawData + (bitPos / CHAR_BIT)); 542 } else { 543 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 544 llvm::divideCeil(bitWidth, CHAR_BIT), 545 rawData + (bitPos / CHAR_BIT)); 546 } 547 } 548 549 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 550 /// `rawData`. 551 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 552 // Handle a boolean bit position. 553 if (bitWidth == 1) 554 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 555 556 // Otherwise, the bit position must be 8-bit aligned. 557 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 558 APInt result(bitWidth, 0); 559 if (llvm::support::endian::system_endianness() == 560 llvm::support::endianness::big) { 561 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 562 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 563 // work correctly in BE format. 564 // ex. `result` (2 words including 10 bytes) 565 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 566 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 567 llvm::divideCeil(bitWidth, CHAR_BIT), result); 568 } else { 569 std::copy_n(rawData + (bitPos / CHAR_BIT), 570 llvm::divideCeil(bitWidth, CHAR_BIT), 571 const_cast<char *>( 572 reinterpret_cast<const char *>(result.getRawData()))); 573 } 574 return result; 575 } 576 577 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 578 /// the same element count as 'type'. 579 template <typename Values> 580 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 581 return (values.size() == 1) || 582 (type.getNumElements() == static_cast<int64_t>(values.size())); 583 } 584 585 //===----------------------------------------------------------------------===// 586 // DenseElementsAttr Iterators 587 //===----------------------------------------------------------------------===// 588 589 //===----------------------------------------------------------------------===// 590 // AttributeElementIterator 591 592 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 593 DenseElementsAttr attr, size_t index) 594 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 595 Attribute, Attribute, Attribute>( 596 attr.getAsOpaquePointer(), index) {} 597 598 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 599 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 600 Type eltTy = owner.getElementType(); 601 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 602 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 603 if (eltTy.isa<IndexType>()) 604 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 605 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 606 IntElementIterator intIt(owner, index); 607 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 608 return FloatAttr::get(eltTy, *floatIt); 609 } 610 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 611 auto complexEltTy = complexTy.getElementType(); 612 ComplexIntElementIterator complexIntIt(owner, index); 613 if (complexEltTy.isa<IntegerType>()) { 614 auto value = *complexIntIt; 615 auto real = IntegerAttr::get(complexEltTy, value.real()); 616 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 617 return ArrayAttr::get(complexTy.getContext(), 618 ArrayRef<Attribute>{real, imag}); 619 } 620 621 ComplexFloatElementIterator complexFloatIt( 622 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 623 auto value = *complexFloatIt; 624 auto real = FloatAttr::get(complexEltTy, value.real()); 625 auto imag = FloatAttr::get(complexEltTy, value.imag()); 626 return ArrayAttr::get(complexTy.getContext(), 627 ArrayRef<Attribute>{real, imag}); 628 } 629 if (owner.isa<DenseStringElementsAttr>()) { 630 ArrayRef<StringRef> vals = owner.getRawStringData(); 631 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 632 } 633 llvm_unreachable("unexpected element type"); 634 } 635 636 //===----------------------------------------------------------------------===// 637 // BoolElementIterator 638 639 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 640 DenseElementsAttr attr, size_t dataIndex) 641 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 642 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 643 644 bool DenseElementsAttr::BoolElementIterator::operator*() const { 645 return getBit(getData(), getDataIndex()); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // IntElementIterator 650 651 DenseElementsAttr::IntElementIterator::IntElementIterator( 652 DenseElementsAttr attr, size_t dataIndex) 653 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 654 attr.getRawData().data(), attr.isSplat(), dataIndex), 655 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 656 657 APInt DenseElementsAttr::IntElementIterator::operator*() const { 658 return readBits(getData(), 659 getDataIndex() * getDenseElementStorageWidth(bitWidth), 660 bitWidth); 661 } 662 663 //===----------------------------------------------------------------------===// 664 // ComplexIntElementIterator 665 666 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 667 DenseElementsAttr attr, size_t dataIndex) 668 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 669 std::complex<APInt>, std::complex<APInt>, 670 std::complex<APInt>>( 671 attr.getRawData().data(), attr.isSplat(), dataIndex) { 672 auto complexType = attr.getElementType().cast<ComplexType>(); 673 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 674 } 675 676 std::complex<APInt> 677 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 678 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 679 size_t offset = getDataIndex() * storageWidth * 2; 680 return {readBits(getData(), offset, bitWidth), 681 readBits(getData(), offset + storageWidth, bitWidth)}; 682 } 683 684 //===----------------------------------------------------------------------===// 685 // DenseArrayAttr 686 //===----------------------------------------------------------------------===// 687 688 /// Custom storage to ensure proper memory alignment for the allocation of 689 /// DenseArray of any element type. 690 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage { 691 using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType, 692 ::llvm::ArrayRef<char>>; 693 DenseArrayBaseAttrStorage(ShapedType type, 694 DenseArrayBaseAttr::EltType eltType, 695 ::llvm::ArrayRef<char> elements) 696 : AttributeStorage(type), eltType(eltType), elements(elements) {} 697 698 bool operator==(const KeyTy &tblgenKey) const { 699 return (getType() == std::get<0>(tblgenKey)) && 700 (eltType == std::get<1>(tblgenKey)) && 701 (elements == std::get<2>(tblgenKey)); 702 } 703 704 static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { 705 return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), 706 std::get<2>(tblgenKey)); 707 } 708 709 static DenseArrayBaseAttrStorage * 710 construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) { 711 auto type = std::get<0>(tblgenKey); 712 auto eltType = std::get<1>(tblgenKey); 713 auto elements = std::get<2>(tblgenKey); 714 if (!elements.empty()) { 715 char *alloc = static_cast<char *>( 716 allocator.allocate(elements.size(), alignof(uint64_t))); 717 std::uninitialized_copy(elements.begin(), elements.end(), alloc); 718 elements = ArrayRef<char>(alloc, elements.size()); 719 } 720 return new (allocator.allocate<DenseArrayBaseAttrStorage>()) 721 DenseArrayBaseAttrStorage(type, eltType, elements); 722 } 723 724 DenseArrayBaseAttr::EltType eltType; 725 ::llvm::ArrayRef<char> elements; 726 }; 727 728 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const { 729 return getImpl()->eltType; 730 } 731 732 const int8_t * 733 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const { 734 return cast<DenseI8ArrayAttr>().asArrayRef().begin(); 735 } 736 const int16_t * 737 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const { 738 return cast<DenseI16ArrayAttr>().asArrayRef().begin(); 739 } 740 const int32_t * 741 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const { 742 return cast<DenseI32ArrayAttr>().asArrayRef().begin(); 743 } 744 const int64_t * 745 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const { 746 return cast<DenseI64ArrayAttr>().asArrayRef().begin(); 747 } 748 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const { 749 return cast<DenseF32ArrayAttr>().asArrayRef().begin(); 750 } 751 const double * 752 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const { 753 return cast<DenseF64ArrayAttr>().asArrayRef().begin(); 754 } 755 756 void DenseArrayBaseAttr::print(AsmPrinter &printer) const { 757 print(printer.getStream()); 758 } 759 760 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { 761 switch (getElementType()) { 762 case DenseArrayBaseAttr::EltType::I8: 763 this->cast<DenseI8ArrayAttr>().printWithoutBraces(os); 764 return; 765 case DenseArrayBaseAttr::EltType::I16: 766 this->cast<DenseI16ArrayAttr>().printWithoutBraces(os); 767 return; 768 case DenseArrayBaseAttr::EltType::I32: 769 this->cast<DenseI32ArrayAttr>().printWithoutBraces(os); 770 return; 771 case DenseArrayBaseAttr::EltType::I64: 772 this->cast<DenseI64ArrayAttr>().printWithoutBraces(os); 773 return; 774 case DenseArrayBaseAttr::EltType::F32: 775 this->cast<DenseF32ArrayAttr>().printWithoutBraces(os); 776 return; 777 case DenseArrayBaseAttr::EltType::F64: 778 this->cast<DenseF64ArrayAttr>().printWithoutBraces(os); 779 return; 780 } 781 llvm_unreachable("<unknown DenseArrayBaseAttr>"); 782 } 783 784 void DenseArrayBaseAttr::print(raw_ostream &os) const { 785 os << "["; 786 printWithoutBraces(os); 787 os << "]"; 788 } 789 790 template <typename T> 791 void DenseArrayAttr<T>::print(AsmPrinter &printer) const { 792 print(printer.getStream()); 793 } 794 795 template <typename T> 796 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const { 797 ArrayRef<T> values{*this}; 798 llvm::interleaveComma(values, os); 799 } 800 801 /// Specialization for int8_t for forcing printing as number instead of chars. 802 template <> 803 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const { 804 ArrayRef<int8_t> values{*this}; 805 llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); 806 } 807 808 template <typename T> 809 void DenseArrayAttr<T>::print(raw_ostream &os) const { 810 os << "["; 811 printWithoutBraces(os); 812 os << "]"; 813 } 814 815 /// Parse a single element: generic template for int types, specialized for 816 /// floating points below. 817 template <typename T> 818 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { 819 return parser.parseInteger(value); 820 } 821 822 template <> 823 ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) { 824 double doubleVal; 825 if (parser.parseFloat(doubleVal)) 826 return failure(); 827 value = doubleVal; 828 return success(); 829 } 830 831 template <> 832 ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) { 833 return parser.parseFloat(value); 834 } 835 836 /// Parse a DenseArrayAttr without the braces: `1, 2, 3` 837 template <typename T> 838 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser, 839 Type odsType) { 840 SmallVector<T> data; 841 if (failed(parser.parseCommaSeparatedList([&]() { 842 T value; 843 if (parseDenseArrayAttrElt(parser, value)) 844 return failure(); 845 data.push_back(value); 846 return success(); 847 }))) 848 return {}; 849 return get(parser.getContext(), data); 850 } 851 852 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` 853 template <typename T> 854 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) { 855 if (parser.parseLSquare()) 856 return {}; 857 // Handle empty list case. 858 if (succeeded(parser.parseOptionalRSquare())) 859 return get(parser.getContext(), {}); 860 Attribute result = parseWithoutBraces(parser, odsType); 861 if (parser.parseRSquare()) 862 return {}; 863 return result; 864 } 865 866 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. 867 template <typename T> 868 DenseArrayAttr<T>::operator ArrayRef<T>() const { 869 ArrayRef<char> raw = getImpl()->elements; 870 assert((raw.size() % sizeof(T)) == 0); 871 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), 872 raw.size() / sizeof(T)); 873 } 874 875 namespace { 876 /// Mapping from C++ element type to MLIR DenseArrayAttr internals. 877 template <typename T> 878 struct denseArrayAttrEltTypeBuilder; 879 template <> 880 struct denseArrayAttrEltTypeBuilder<int8_t> { 881 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8; 882 static ShapedType getShapedType(MLIRContext *context, 883 ArrayRef<int64_t> shape) { 884 return VectorType::get(shape, IntegerType::get(context, 8)); 885 } 886 }; 887 template <> 888 struct denseArrayAttrEltTypeBuilder<int16_t> { 889 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16; 890 static ShapedType getShapedType(MLIRContext *context, 891 ArrayRef<int64_t> shape) { 892 return VectorType::get(shape, IntegerType::get(context, 16)); 893 } 894 }; 895 template <> 896 struct denseArrayAttrEltTypeBuilder<int32_t> { 897 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32; 898 static ShapedType getShapedType(MLIRContext *context, 899 ArrayRef<int64_t> shape) { 900 return VectorType::get(shape, IntegerType::get(context, 32)); 901 } 902 }; 903 template <> 904 struct denseArrayAttrEltTypeBuilder<int64_t> { 905 constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64; 906 static ShapedType getShapedType(MLIRContext *context, 907 ArrayRef<int64_t> shape) { 908 return VectorType::get(shape, IntegerType::get(context, 64)); 909 } 910 }; 911 template <> 912 struct denseArrayAttrEltTypeBuilder<float> { 913 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; 914 static ShapedType getShapedType(MLIRContext *context, 915 ArrayRef<int64_t> shape) { 916 return VectorType::get(shape, Float32Type::get(context)); 917 } 918 }; 919 template <> 920 struct denseArrayAttrEltTypeBuilder<double> { 921 constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64; 922 static ShapedType getShapedType(MLIRContext *context, 923 ArrayRef<int64_t> shape) { 924 return VectorType::get(shape, Float64Type::get(context)); 925 } 926 }; 927 } // namespace 928 929 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. 930 template <typename T> 931 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context, 932 ArrayRef<T> content) { 933 auto size = static_cast<int64_t>(content.size()); 934 auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType( 935 context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{}); 936 auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType; 937 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), 938 content.size() * sizeof(T)); 939 return Base::get(context, shapedType, eltType, rawArray) 940 .template cast<DenseArrayAttr<T>>(); 941 } 942 943 template <typename T> 944 bool DenseArrayAttr<T>::classof(Attribute attr) { 945 return attr.isa<DenseArrayBaseAttr>() && 946 attr.cast<DenseArrayBaseAttr>().getElementType() == 947 denseArrayAttrEltTypeBuilder<T>::eltType; 948 } 949 950 namespace mlir { 951 namespace detail { 952 // Explicit instantiation for all the supported DenseArrayAttr. 953 template class DenseArrayAttr<int8_t>; 954 template class DenseArrayAttr<int16_t>; 955 template class DenseArrayAttr<int32_t>; 956 template class DenseArrayAttr<int64_t>; 957 template class DenseArrayAttr<float>; 958 template class DenseArrayAttr<double>; 959 } // namespace detail 960 } // namespace mlir 961 962 //===----------------------------------------------------------------------===// 963 // DenseElementsAttr 964 //===----------------------------------------------------------------------===// 965 966 /// Method for support type inquiry through isa, cast and dyn_cast. 967 bool DenseElementsAttr::classof(Attribute attr) { 968 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 969 } 970 971 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 972 ArrayRef<Attribute> values) { 973 assert(hasSameElementsOrSplat(type, values)); 974 975 // If the element type is not based on int/float/index, assume it is a string 976 // type. 977 auto eltType = type.getElementType(); 978 if (!type.getElementType().isIntOrIndexOrFloat()) { 979 SmallVector<StringRef, 8> stringValues; 980 stringValues.reserve(values.size()); 981 for (Attribute attr : values) { 982 assert(attr.isa<StringAttr>() && 983 "expected string value for non integer/index/float element"); 984 stringValues.push_back(attr.cast<StringAttr>().getValue()); 985 } 986 return get(type, stringValues); 987 } 988 989 // Otherwise, get the raw storage width to use for the allocation. 990 size_t bitWidth = getDenseElementBitWidth(eltType); 991 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 992 993 // Compress the attribute values into a character buffer. 994 SmallVector<char, 8> data( 995 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); 996 APInt intVal; 997 for (unsigned i = 0, e = values.size(); i < e; ++i) { 998 assert(eltType == values[i].getType() && 999 "expected attribute value to have element type"); 1000 if (eltType.isa<FloatType>()) 1001 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 1002 else if (eltType.isa<IntegerType, IndexType>()) 1003 intVal = values[i].cast<IntegerAttr>().getValue(); 1004 else 1005 llvm_unreachable("unexpected element type"); 1006 1007 assert(intVal.getBitWidth() == bitWidth && 1008 "expected value to have same bitwidth as element type"); 1009 writeBits(data.data(), i * storageBitWidth, intVal); 1010 } 1011 1012 // Handle the special encoding of splat of bool. 1013 if (values.size() == 1 && values[0].getType().isInteger(1)) 1014 data[0] = data[0] ? -1 : 0; 1015 1016 return DenseIntOrFPElementsAttr::getRaw(type, data); 1017 } 1018 1019 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1020 ArrayRef<bool> values) { 1021 assert(hasSameElementsOrSplat(type, values)); 1022 assert(type.getElementType().isInteger(1)); 1023 1024 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 1025 1026 if (!values.empty()) { 1027 bool isSplat = true; 1028 bool firstValue = values[0]; 1029 for (int i = 0, e = values.size(); i != e; ++i) { 1030 isSplat &= values[i] == firstValue; 1031 setBit(buff.data(), i, values[i]); 1032 } 1033 1034 // Splat of bool is encoded as a byte with all-ones in it. 1035 if (isSplat) { 1036 buff.resize(1); 1037 buff[0] = values[0] ? -1 : 0; 1038 } 1039 } 1040 1041 return DenseIntOrFPElementsAttr::getRaw(type, buff); 1042 } 1043 1044 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1045 ArrayRef<StringRef> values) { 1046 assert(!type.getElementType().isIntOrFloat()); 1047 return DenseStringElementsAttr::get(type, values); 1048 } 1049 1050 /// Constructs a dense integer elements attribute from an array of APInt 1051 /// values. Each APInt value is expected to have the same bitwidth as the 1052 /// element type of 'type'. 1053 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1054 ArrayRef<APInt> values) { 1055 assert(type.getElementType().isIntOrIndex()); 1056 assert(hasSameElementsOrSplat(type, values)); 1057 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1058 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1059 } 1060 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1061 ArrayRef<std::complex<APInt>> values) { 1062 ComplexType complex = type.getElementType().cast<ComplexType>(); 1063 assert(complex.getElementType().isa<IntegerType>()); 1064 assert(hasSameElementsOrSplat(type, values)); 1065 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1066 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 1067 values.size() * 2); 1068 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); 1069 } 1070 1071 // Constructs a dense float elements attribute from an array of APFloat 1072 // values. Each APFloat value is expected to have the same bitwidth as the 1073 // element type of 'type'. 1074 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1075 ArrayRef<APFloat> values) { 1076 assert(type.getElementType().isa<FloatType>()); 1077 assert(hasSameElementsOrSplat(type, values)); 1078 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1079 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1080 } 1081 DenseElementsAttr 1082 DenseElementsAttr::get(ShapedType type, 1083 ArrayRef<std::complex<APFloat>> values) { 1084 ComplexType complex = type.getElementType().cast<ComplexType>(); 1085 assert(complex.getElementType().isa<FloatType>()); 1086 assert(hasSameElementsOrSplat(type, values)); 1087 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 1088 values.size() * 2); 1089 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1090 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); 1091 } 1092 1093 /// Construct a dense elements attribute from a raw buffer representing the 1094 /// data for this attribute. Users should generally not use this methods as 1095 /// the expected buffer format may not be a form the user expects. 1096 DenseElementsAttr 1097 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { 1098 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); 1099 } 1100 1101 /// Returns true if the given buffer is a valid raw buffer for the given type. 1102 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 1103 ArrayRef<char> rawBuffer, 1104 bool &detectedSplat) { 1105 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 1106 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 1107 int64_t numElements = type.getNumElements(); 1108 1109 // The initializer is always a splat if the result type has a single element. 1110 detectedSplat = numElements == 1; 1111 1112 // Storage width of 1 is special as it is packed by the bit. 1113 if (storageWidth == 1) { 1114 // Check for a splat, or a buffer equal to the number of elements which 1115 // consists of either all 0's or all 1's. 1116 if (rawBuffer.size() == 1) { 1117 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 1118 if (rawByte == 0 || rawByte == 0xff) { 1119 detectedSplat = true; 1120 return true; 1121 } 1122 } 1123 1124 // This is a valid non-splat buffer if it has the right size. 1125 return rawBufferWidth == llvm::alignTo<8>(numElements); 1126 } 1127 1128 // All other types are 8-bit aligned, so we can just check the buffer width 1129 // to know if only a single initializer element was passed in. 1130 if (rawBufferWidth == storageWidth) { 1131 detectedSplat = true; 1132 return true; 1133 } 1134 1135 // The raw buffer is valid if it has the right size. 1136 return rawBufferWidth == storageWidth * numElements; 1137 } 1138 1139 /// Check the information for a C++ data type, check if this type is valid for 1140 /// the current attribute. This method is used to verify specific type 1141 /// invariants that the templatized 'getValues' method cannot. 1142 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 1143 bool isSigned) { 1144 // Make sure that the data element size is the same as the type element width. 1145 if (getDenseElementBitWidth(type) != 1146 static_cast<size_t>(dataEltSize * CHAR_BIT)) 1147 return false; 1148 1149 // Check that the element type is either float or integer or index. 1150 if (!isInt) 1151 return type.isa<FloatType>(); 1152 if (type.isIndex()) 1153 return true; 1154 1155 auto intType = type.dyn_cast<IntegerType>(); 1156 if (!intType) 1157 return false; 1158 1159 // Make sure signedness semantics is consistent. 1160 if (intType.isSignless()) 1161 return true; 1162 return intType.isSigned() ? isSigned : !isSigned; 1163 } 1164 1165 /// Defaults down the subclass implementation. 1166 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 1167 ArrayRef<char> data, 1168 int64_t dataEltSize, 1169 bool isInt, bool isSigned) { 1170 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 1171 isSigned); 1172 } 1173 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 1174 ArrayRef<char> data, 1175 int64_t dataEltSize, 1176 bool isInt, 1177 bool isSigned) { 1178 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 1179 isInt, isSigned); 1180 } 1181 1182 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 1183 bool isSigned) const { 1184 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 1185 } 1186 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 1187 bool isSigned) const { 1188 return ::isValidIntOrFloat( 1189 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, 1190 isInt, isSigned); 1191 } 1192 1193 /// Returns true if this attribute corresponds to a splat, i.e. if all element 1194 /// values are the same. 1195 bool DenseElementsAttr::isSplat() const { 1196 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 1197 } 1198 1199 /// Return if the given complex type has an integer element type. 1200 LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { 1201 return type.cast<ComplexType>().getElementType().isa<IntegerType>(); 1202 } 1203 1204 auto DenseElementsAttr::getComplexIntValues() const 1205 -> iterator_range_impl<ComplexIntElementIterator> { 1206 assert(isComplexOfIntType(getElementType()) && 1207 "expected complex integral type"); 1208 return {getType(), ComplexIntElementIterator(*this, 0), 1209 ComplexIntElementIterator(*this, getNumElements())}; 1210 } 1211 auto DenseElementsAttr::complex_value_begin() const 1212 -> ComplexIntElementIterator { 1213 assert(isComplexOfIntType(getElementType()) && 1214 "expected complex integral type"); 1215 return ComplexIntElementIterator(*this, 0); 1216 } 1217 auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { 1218 assert(isComplexOfIntType(getElementType()) && 1219 "expected complex integral type"); 1220 return ComplexIntElementIterator(*this, getNumElements()); 1221 } 1222 1223 /// Return the held element values as a range of APFloat. The element type of 1224 /// this attribute must be of float type. 1225 auto DenseElementsAttr::getFloatValues() const 1226 -> iterator_range_impl<FloatElementIterator> { 1227 auto elementType = getElementType().cast<FloatType>(); 1228 const auto &elementSemantics = elementType.getFloatSemantics(); 1229 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 1230 FloatElementIterator(elementSemantics, raw_int_end())}; 1231 } 1232 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1233 auto elementType = getElementType().cast<FloatType>(); 1234 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); 1235 } 1236 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1237 auto elementType = getElementType().cast<FloatType>(); 1238 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); 1239 } 1240 1241 auto DenseElementsAttr::getComplexFloatValues() const 1242 -> iterator_range_impl<ComplexFloatElementIterator> { 1243 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1244 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1245 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1246 return {getType(), 1247 {semantics, {*this, 0}}, 1248 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1249 } 1250 auto DenseElementsAttr::complex_float_value_begin() const 1251 -> ComplexFloatElementIterator { 1252 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1253 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1254 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; 1255 } 1256 auto DenseElementsAttr::complex_float_value_end() const 1257 -> ComplexFloatElementIterator { 1258 Type eltTy = getElementType().cast<ComplexType>().getElementType(); 1259 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1260 return {eltTy.cast<FloatType>().getFloatSemantics(), 1261 {*this, static_cast<size_t>(getNumElements())}}; 1262 } 1263 1264 /// Return the raw storage data held by this attribute. 1265 ArrayRef<char> DenseElementsAttr::getRawData() const { 1266 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1267 } 1268 1269 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1270 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1271 } 1272 1273 /// Return a new DenseElementsAttr that has the same data as the current 1274 /// attribute, but has been reshaped to 'newType'. The new type must have the 1275 /// same total number of elements as well as element type. 1276 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1277 ShapedType curType = getType(); 1278 if (curType == newType) 1279 return *this; 1280 1281 assert(newType.getElementType() == curType.getElementType() && 1282 "expected the same element type"); 1283 assert(newType.getNumElements() == curType.getNumElements() && 1284 "expected the same number of elements"); 1285 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1286 } 1287 1288 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { 1289 assert(isSplat() && "expected a splat type"); 1290 1291 ShapedType curType = getType(); 1292 if (curType == newType) 1293 return *this; 1294 1295 assert(newType.getElementType() == curType.getElementType() && 1296 "expected the same element type"); 1297 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1298 } 1299 1300 /// Return a new DenseElementsAttr that has the same data as the current 1301 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1302 /// type must have the same shape and element types of the same bitwidth as the 1303 /// current type. 1304 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1305 ShapedType curType = getType(); 1306 Type curElType = curType.getElementType(); 1307 if (curElType == newElType) 1308 return *this; 1309 1310 assert(getDenseElementBitWidth(newElType) == 1311 getDenseElementBitWidth(curElType) && 1312 "expected element types with the same bitwidth"); 1313 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1314 getRawData()); 1315 } 1316 1317 DenseElementsAttr 1318 DenseElementsAttr::mapValues(Type newElementType, 1319 function_ref<APInt(const APInt &)> mapping) const { 1320 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1321 } 1322 1323 DenseElementsAttr DenseElementsAttr::mapValues( 1324 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1325 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1326 } 1327 1328 ShapedType DenseElementsAttr::getType() const { 1329 return Attribute::getType().cast<ShapedType>(); 1330 } 1331 1332 Type DenseElementsAttr::getElementType() const { 1333 return getType().getElementType(); 1334 } 1335 1336 int64_t DenseElementsAttr::getNumElements() const { 1337 return getType().getNumElements(); 1338 } 1339 1340 //===----------------------------------------------------------------------===// 1341 // DenseIntOrFPElementsAttr 1342 //===----------------------------------------------------------------------===// 1343 1344 /// Utility method to write a range of APInt values to a buffer. 1345 template <typename APRangeT> 1346 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1347 APRangeT &&values) { 1348 size_t numValues = llvm::size(values); 1349 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); 1350 size_t offset = 0; 1351 for (auto it = values.begin(), e = values.end(); it != e; 1352 ++it, offset += storageWidth) { 1353 assert((*it).getBitWidth() <= storageWidth); 1354 writeBits(data.data(), offset, *it); 1355 } 1356 1357 // Handle the special encoding of splat of a boolean. 1358 if (numValues == 1 && (*values.begin()).getBitWidth() == 1) 1359 data[0] = data[0] ? -1 : 0; 1360 } 1361 1362 /// Constructs a dense elements attribute from an array of raw APFloat values. 1363 /// Each APFloat value is expected to have the same bitwidth as the element 1364 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1365 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1366 size_t storageWidth, 1367 ArrayRef<APFloat> values) { 1368 std::vector<char> data; 1369 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1370 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1371 return DenseIntOrFPElementsAttr::getRaw(type, data); 1372 } 1373 1374 /// Constructs a dense elements attribute from an array of raw APInt values. 1375 /// Each APInt value is expected to have the same bitwidth as the element type 1376 /// of 'type'. 1377 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1378 size_t storageWidth, 1379 ArrayRef<APInt> values) { 1380 std::vector<char> data; 1381 writeAPIntsToBuffer(storageWidth, data, values); 1382 return DenseIntOrFPElementsAttr::getRaw(type, data); 1383 } 1384 1385 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1386 ArrayRef<char> data) { 1387 assert((type.isa<RankedTensorType, VectorType>()) && 1388 "type must be ranked tensor or vector"); 1389 assert(type.hasStaticShape() && "type must have static shape"); 1390 bool isSplat = false; 1391 bool isValid = isValidRawBuffer(type, data, isSplat); 1392 assert(isValid); 1393 (void)isValid; 1394 return Base::get(type.getContext(), type, data, isSplat); 1395 } 1396 1397 /// Overload of the raw 'get' method that asserts that the given type is of 1398 /// complex type. This method is used to verify type invariants that the 1399 /// templatized 'get' method cannot. 1400 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1401 ArrayRef<char> data, 1402 int64_t dataEltSize, 1403 bool isInt, 1404 bool isSigned) { 1405 assert(::isValidIntOrFloat( 1406 type.getElementType().cast<ComplexType>().getElementType(), 1407 dataEltSize / 2, isInt, isSigned)); 1408 1409 int64_t numElements = data.size() / dataEltSize; 1410 (void)numElements; 1411 assert(numElements == 1 || numElements == type.getNumElements()); 1412 return getRaw(type, data); 1413 } 1414 1415 /// Overload of the 'getRaw' method that asserts that the given type is of 1416 /// integer type. This method is used to verify type invariants that the 1417 /// templatized 'get' method cannot. 1418 DenseElementsAttr 1419 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1420 int64_t dataEltSize, bool isInt, 1421 bool isSigned) { 1422 assert( 1423 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1424 1425 int64_t numElements = data.size() / dataEltSize; 1426 assert(numElements == 1 || numElements == type.getNumElements()); 1427 (void)numElements; 1428 return getRaw(type, data); 1429 } 1430 1431 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1432 const char *inRawData, char *outRawData, size_t elementBitWidth, 1433 size_t numElements) { 1434 using llvm::support::ulittle16_t; 1435 using llvm::support::ulittle32_t; 1436 using llvm::support::ulittle64_t; 1437 1438 assert(llvm::support::endian::system_endianness() == // NOLINT 1439 llvm::support::endianness::big); // NOLINT 1440 // NOLINT to avoid warning message about replacing by static_assert() 1441 1442 // Following std::copy_n always converts endianness on BE machine. 1443 switch (elementBitWidth) { 1444 case 16: { 1445 const ulittle16_t *inRawDataPos = 1446 reinterpret_cast<const ulittle16_t *>(inRawData); 1447 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1448 std::copy_n(inRawDataPos, numElements, outDataPos); 1449 break; 1450 } 1451 case 32: { 1452 const ulittle32_t *inRawDataPos = 1453 reinterpret_cast<const ulittle32_t *>(inRawData); 1454 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1455 std::copy_n(inRawDataPos, numElements, outDataPos); 1456 break; 1457 } 1458 case 64: { 1459 const ulittle64_t *inRawDataPos = 1460 reinterpret_cast<const ulittle64_t *>(inRawData); 1461 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1462 std::copy_n(inRawDataPos, numElements, outDataPos); 1463 break; 1464 } 1465 default: { 1466 size_t nBytes = elementBitWidth / CHAR_BIT; 1467 for (size_t i = 0; i < nBytes; i++) 1468 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1469 break; 1470 } 1471 } 1472 } 1473 1474 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1475 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1476 ShapedType type) { 1477 size_t numElements = type.getNumElements(); 1478 Type elementType = type.getElementType(); 1479 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1480 elementType = complexTy.getElementType(); 1481 numElements = numElements * 2; 1482 } 1483 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1484 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1485 inRawData.size() <= outRawData.size()); 1486 if (elementBitWidth <= CHAR_BIT) 1487 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); 1488 else 1489 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1490 elementBitWidth, numElements); 1491 } 1492 1493 //===----------------------------------------------------------------------===// 1494 // DenseFPElementsAttr 1495 //===----------------------------------------------------------------------===// 1496 1497 template <typename Fn, typename Attr> 1498 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1499 Type newElementType, 1500 llvm::SmallVectorImpl<char> &data) { 1501 size_t bitWidth = getDenseElementBitWidth(newElementType); 1502 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1503 1504 ShapedType newArrayType; 1505 if (inType.isa<RankedTensorType>()) 1506 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1507 else if (inType.isa<UnrankedTensorType>()) 1508 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1509 else if (auto vType = inType.dyn_cast<VectorType>()) 1510 newArrayType = VectorType::get(vType.getShape(), newElementType, 1511 vType.getNumScalableDims()); 1512 else 1513 assert(newArrayType && "Unhandled tensor type"); 1514 1515 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1516 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); 1517 1518 // Functor used to process a single element value of the attribute. 1519 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1520 auto newInt = mapping(value); 1521 assert(newInt.getBitWidth() == bitWidth); 1522 writeBits(data.data(), index * storageBitWidth, newInt); 1523 }; 1524 1525 // Check for the splat case. 1526 if (attr.isSplat()) { 1527 processElt(*attr.begin(), /*index=*/0); 1528 return newArrayType; 1529 } 1530 1531 // Otherwise, process all of the element values. 1532 uint64_t elementIdx = 0; 1533 for (auto value : attr) 1534 processElt(value, elementIdx++); 1535 return newArrayType; 1536 } 1537 1538 DenseElementsAttr DenseFPElementsAttr::mapValues( 1539 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1540 llvm::SmallVector<char, 8> elementData; 1541 auto newArrayType = 1542 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1543 1544 return getRaw(newArrayType, elementData); 1545 } 1546 1547 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1548 bool DenseFPElementsAttr::classof(Attribute attr) { 1549 return attr.isa<DenseElementsAttr>() && 1550 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1551 } 1552 1553 //===----------------------------------------------------------------------===// 1554 // DenseIntElementsAttr 1555 //===----------------------------------------------------------------------===// 1556 1557 DenseElementsAttr DenseIntElementsAttr::mapValues( 1558 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1559 llvm::SmallVector<char, 8> elementData; 1560 auto newArrayType = 1561 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1562 return getRaw(newArrayType, elementData); 1563 } 1564 1565 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1566 bool DenseIntElementsAttr::classof(Attribute attr) { 1567 return attr.isa<DenseElementsAttr>() && 1568 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1569 } 1570 1571 //===----------------------------------------------------------------------===// 1572 // OpaqueElementsAttr 1573 //===----------------------------------------------------------------------===// 1574 1575 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1576 Dialect *dialect = getContext()->getLoadedDialect(getDialect()); 1577 if (!dialect) 1578 return true; 1579 auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect); 1580 if (!interface) 1581 return true; 1582 return failed(interface->decode(*this, result)); 1583 } 1584 1585 LogicalResult 1586 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1587 StringAttr dialect, StringRef value, 1588 ShapedType type) { 1589 if (!Dialect::isValidNamespace(dialect.strref())) 1590 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1591 return success(); 1592 } 1593 1594 //===----------------------------------------------------------------------===// 1595 // SparseElementsAttr 1596 //===----------------------------------------------------------------------===// 1597 1598 /// Get a zero APFloat for the given sparse attribute. 1599 APFloat SparseElementsAttr::getZeroAPFloat() const { 1600 auto eltType = getElementType().cast<FloatType>(); 1601 return APFloat(eltType.getFloatSemantics()); 1602 } 1603 1604 /// Get a zero APInt for the given sparse attribute. 1605 APInt SparseElementsAttr::getZeroAPInt() const { 1606 auto eltType = getElementType().cast<IntegerType>(); 1607 return APInt::getZero(eltType.getWidth()); 1608 } 1609 1610 /// Get a zero attribute for the given attribute type. 1611 Attribute SparseElementsAttr::getZeroAttr() const { 1612 auto eltType = getElementType(); 1613 1614 // Handle floating point elements. 1615 if (eltType.isa<FloatType>()) 1616 return FloatAttr::get(eltType, 0); 1617 1618 // Handle complex elements. 1619 if (auto complexTy = eltType.dyn_cast<ComplexType>()) { 1620 auto eltType = complexTy.getElementType(); 1621 Attribute zero; 1622 if (eltType.isa<FloatType>()) 1623 zero = FloatAttr::get(eltType, 0); 1624 else // must be integer 1625 zero = IntegerAttr::get(eltType, 0); 1626 return ArrayAttr::get(complexTy.getContext(), 1627 ArrayRef<Attribute>{zero, zero}); 1628 } 1629 1630 // Handle string type. 1631 if (getValues().isa<DenseStringElementsAttr>()) 1632 return StringAttr::get("", eltType); 1633 1634 // Otherwise, this is an integer. 1635 return IntegerAttr::get(eltType, 0); 1636 } 1637 1638 /// Flatten, and return, all of the sparse indices in this attribute in 1639 /// row-major order. 1640 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1641 std::vector<ptrdiff_t> flatSparseIndices; 1642 1643 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1644 // as a 1-D index array. 1645 auto sparseIndices = getIndices(); 1646 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1647 if (sparseIndices.isSplat()) { 1648 SmallVector<uint64_t, 8> indices(getType().getRank(), 1649 *sparseIndexValues.begin()); 1650 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1651 return flatSparseIndices; 1652 } 1653 1654 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1655 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1656 size_t rank = getType().getRank(); 1657 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1658 flatSparseIndices.push_back(getFlattenedIndex( 1659 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1660 return flatSparseIndices; 1661 } 1662 1663 LogicalResult 1664 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1665 ShapedType type, DenseIntElementsAttr sparseIndices, 1666 DenseElementsAttr values) { 1667 ShapedType valuesType = values.getType(); 1668 if (valuesType.getRank() != 1) 1669 return emitError() << "expected 1-d tensor for sparse element values"; 1670 1671 // Verify the indices and values shape. 1672 ShapedType indicesType = sparseIndices.getType(); 1673 auto emitShapeError = [&]() { 1674 return emitError() << "expected shape ([" << type.getShape() 1675 << "]); inferred shape of indices literal ([" 1676 << indicesType.getShape() 1677 << "]); inferred shape of values literal ([" 1678 << valuesType.getShape() << "])"; 1679 }; 1680 // Verify indices shape. 1681 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1682 if (indicesRank == 2) { 1683 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1684 return emitShapeError(); 1685 } else if (indicesRank != 1 || rank != 1) { 1686 return emitShapeError(); 1687 } 1688 // Verify the values shape. 1689 int64_t numSparseIndices = indicesType.getDimSize(0); 1690 if (numSparseIndices != valuesType.getDimSize(0)) 1691 return emitShapeError(); 1692 1693 // Verify that the sparse indices are within the value shape. 1694 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1695 return emitError() 1696 << "sparse index #" << indexNum 1697 << " is not contained within the value shape, with index=[" << index 1698 << "], and type=" << type; 1699 }; 1700 1701 // Handle the case where the index values are a splat. 1702 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1703 if (sparseIndices.isSplat()) { 1704 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1705 if (!ElementsAttr::isValidIndex(type, indices)) 1706 return emitIndexError(0, indices); 1707 return success(); 1708 } 1709 1710 // Otherwise, reinterpret each index as an ArrayRef. 1711 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1712 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1713 rank); 1714 if (!ElementsAttr::isValidIndex(type, index)) 1715 return emitIndexError(i, index); 1716 } 1717 1718 return success(); 1719 } 1720 1721 //===----------------------------------------------------------------------===// 1722 // TypeAttr 1723 //===----------------------------------------------------------------------===// 1724 1725 void TypeAttr::walkImmediateSubElements( 1726 function_ref<void(Attribute)> walkAttrsFn, 1727 function_ref<void(Type)> walkTypesFn) const { 1728 walkTypesFn(getValue()); 1729 } 1730 1731 Attribute 1732 TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 1733 ArrayRef<Type> replTypes) const { 1734 return get(replTypes[0]); 1735 } 1736