1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 19 #include "llvm/ADT/APSInt.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/Support/Endian.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Attribute Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/IR/BuiltinAttributes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerAttributes() { 38 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 39 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 40 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 41 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 42 UnitAttr>(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // ArrayAttr 47 //===----------------------------------------------------------------------===// 48 49 void ArrayAttr::walkImmediateSubElements( 50 function_ref<void(Attribute)> 
walkAttrsFn, 51 function_ref<void(Type)> walkTypesFn) const { 52 for (Attribute attr : getValue()) 53 walkAttrsFn(attr); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // DictionaryAttr 58 //===----------------------------------------------------------------------===// 59 60 /// Helper function that does either an in place sort or sorts from source array 61 /// into destination. If inPlace then storage is both the source and the 62 /// destination, else value is the source and storage destination. Returns 63 /// whether source was sorted. 64 template <bool inPlace> 65 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 66 SmallVectorImpl<NamedAttribute> &storage) { 67 // Specialize for the common case. 68 switch (value.size()) { 69 case 0: 70 // Zero already sorted. 71 break; 72 case 1: 73 // One already sorted but may need to be copied. 74 if (!inPlace) 75 storage.assign({value[0]}); 76 break; 77 case 2: { 78 bool isSorted = value[0] < value[1]; 79 if (inPlace) { 80 if (!isSorted) 81 std::swap(storage[0], storage[1]); 82 } else if (isSorted) { 83 storage.assign({value[0], value[1]}); 84 } else { 85 storage.assign({value[1], value[0]}); 86 } 87 return !isSorted; 88 } 89 default: 90 if (!inPlace) 91 storage.assign(value.begin(), value.end()); 92 // Check to see they are sorted already. 93 bool isSorted = llvm::is_sorted(value); 94 // If not, do a general sort. 95 if (!isSorted) 96 llvm::array_pod_sort(storage.begin(), storage.end()); 97 return !isSorted; 98 } 99 return false; 100 } 101 102 /// Returns an entry with a duplicate name from the given sorted array of named 103 /// attributes. Returns llvm::None if all elements have unique names. 104 static Optional<NamedAttribute> 105 findDuplicateElement(ArrayRef<NamedAttribute> value) { 106 const Optional<NamedAttribute> none{llvm::None}; 107 if (value.size() < 2) 108 return none; 109 110 if (value.size() == 2) 111 return value[0].first == value[1].first ? 
value[0] : none; 112 113 auto it = std::adjacent_find( 114 value.begin(), value.end(), 115 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 116 return it != value.end() ? *it : none; 117 } 118 119 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 120 SmallVectorImpl<NamedAttribute> &storage) { 121 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 122 assert(!findDuplicateElement(storage) && 123 "DictionaryAttr element names must be unique"); 124 return isSorted; 125 } 126 127 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 128 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 129 assert(!findDuplicateElement(array) && 130 "DictionaryAttr element names must be unique"); 131 return isSorted; 132 } 133 134 Optional<NamedAttribute> 135 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 136 bool isSorted) { 137 if (!isSorted) 138 dictionaryAttrSort</*inPlace=*/true>(array, array); 139 return findDuplicateElement(array); 140 } 141 142 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 143 ArrayRef<NamedAttribute> value) { 144 if (value.empty()) 145 return DictionaryAttr::getEmpty(context); 146 assert(llvm::all_of(value, 147 [](const NamedAttribute &attr) { return attr.second; }) && 148 "value cannot have null entries"); 149 150 // We need to sort the element list to canonicalize it. 151 SmallVector<NamedAttribute, 8> storage; 152 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 153 value = storage; 154 assert(!findDuplicateElement(value) && 155 "DictionaryAttr element names must be unique"); 156 return Base::get(context, value); 157 } 158 /// Construct a dictionary with an array of values that is known to already be 159 /// sorted by name and uniqued. 
160 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 161 ArrayRef<NamedAttribute> value) { 162 if (value.empty()) 163 return DictionaryAttr::getEmpty(context); 164 // Ensure that the attribute elements are unique and sorted. 165 assert(llvm::is_sorted(value, 166 [](NamedAttribute l, NamedAttribute r) { 167 return l.first.strref() < r.first.strref(); 168 }) && 169 "expected attribute values to be sorted"); 170 assert(!findDuplicateElement(value) && 171 "DictionaryAttr element names must be unique"); 172 return Base::get(context, value); 173 } 174 175 /// Return the specified attribute if present, null otherwise. 176 Attribute DictionaryAttr::get(StringRef name) const { 177 Optional<NamedAttribute> attr = getNamed(name); 178 return attr ? attr->second : nullptr; 179 } 180 Attribute DictionaryAttr::get(Identifier name) const { 181 Optional<NamedAttribute> attr = getNamed(name); 182 return attr ? attr->second : nullptr; 183 } 184 185 /// Return the specified named attribute if present, None otherwise. 186 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 187 ArrayRef<NamedAttribute> values = getValue(); 188 const auto *it = llvm::lower_bound(values, name); 189 return it != values.end() && it->first == name ? 
*it 190 : Optional<NamedAttribute>(); 191 } 192 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 193 for (auto elt : getValue()) 194 if (elt.first == name) 195 return elt; 196 return llvm::None; 197 } 198 199 DictionaryAttr::iterator DictionaryAttr::begin() const { 200 return getValue().begin(); 201 } 202 DictionaryAttr::iterator DictionaryAttr::end() const { 203 return getValue().end(); 204 } 205 size_t DictionaryAttr::size() const { return getValue().size(); } 206 207 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 208 return Base::get(context, ArrayRef<NamedAttribute>()); 209 } 210 211 void DictionaryAttr::walkImmediateSubElements( 212 function_ref<void(Attribute)> walkAttrsFn, 213 function_ref<void(Type)> walkTypesFn) const { 214 for (Attribute attr : llvm::make_second_range(getValue())) 215 walkAttrsFn(attr); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // StringAttr 220 //===----------------------------------------------------------------------===// 221 222 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 223 return Base::get(context, "", NoneType::get(context)); 224 } 225 226 /// Twine support for StringAttr. 227 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 228 // Fast-path empty twine. 229 if (twine.isTriviallyEmpty()) 230 return get(context); 231 SmallVector<char, 32> tempStr; 232 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 233 } 234 235 /// Twine support for StringAttr. 
236 StringAttr StringAttr::get(const Twine &twine, Type type) { 237 SmallVector<char, 32> tempStr; 238 return Base::get(type.getContext(), twine.toStringRef(tempStr), type); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // FloatAttr 243 //===----------------------------------------------------------------------===// 244 245 double FloatAttr::getValueAsDouble() const { 246 return getValueAsDouble(getValue()); 247 } 248 double FloatAttr::getValueAsDouble(APFloat value) { 249 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 250 bool losesInfo = false; 251 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 252 &losesInfo); 253 } 254 return value.convertToDouble(); 255 } 256 257 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 258 Type type, APFloat value) { 259 // Verify that the type is correct. 260 if (!type.isa<FloatType>()) 261 return emitError() << "expected floating point type"; 262 263 // Verify that the type semantics match that of the value. 
264 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 265 return emitError() 266 << "FloatAttr type doesn't match the type implied by its value"; 267 } 268 return success(); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // SymbolRefAttr 273 //===----------------------------------------------------------------------===// 274 275 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 276 ArrayRef<FlatSymbolRefAttr> nestedRefs) { 277 return get(StringAttr::get(ctx, value), nestedRefs); 278 } 279 280 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 281 return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); 282 } 283 284 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { 285 return get(value, {}).cast<FlatSymbolRefAttr>(); 286 } 287 288 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { 289 auto symName = 290 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 291 assert(symName && "value does not have a valid symbol name"); 292 return SymbolRefAttr::get(symName); 293 } 294 295 StringAttr SymbolRefAttr::getLeafReference() const { 296 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 297 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // IntegerAttr 302 //===----------------------------------------------------------------------===// 303 304 int64_t IntegerAttr::getInt() const { 305 assert((getType().isIndex() || getType().isSignlessInteger()) && 306 "must be signless integer"); 307 return getValue().getSExtValue(); 308 } 309 310 int64_t IntegerAttr::getSInt() const { 311 assert(getType().isSignedInteger() && "must be signed integer"); 312 return getValue().getSExtValue(); 313 } 314 315 uint64_t IntegerAttr::getUInt() const { 316 assert(getType().isUnsignedInteger() && "must be unsigned integer"); 317 return getValue().getZExtValue(); 318 } 319 320 /// Return the value as an APSInt which carries the signed from the type of 321 /// the attribute. This traps on signless integers types! 322 APSInt IntegerAttr::getAPSInt() const { 323 assert(!getType().isSignlessInteger() && 324 "Signless integers don't carry a sign for APSInt"); 325 return APSInt(getValue(), getType().isUnsignedInteger()); 326 } 327 328 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 329 Type type, APInt value) { 330 if (IntegerType integerType = type.dyn_cast<IntegerType>()) { 331 if (integerType.getWidth() != value.getBitWidth()) 332 return emitError() << "integer type bit width (" << integerType.getWidth() 333 << ") doesn't match value bit width (" 334 << value.getBitWidth() << ")"; 335 return success(); 336 } 337 if (type.isa<IndexType>()) 338 return success(); 339 return emitError() << "expected integer or index type"; 340 } 341 342 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { 343 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); 344 return attr.cast<BoolAttr>(); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // BoolAttr 349 350 bool 
BoolAttr::getValue() const { 351 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); 352 return storage->value.getBoolValue(); 353 } 354 355 bool BoolAttr::classof(Attribute attr) { 356 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 357 return intAttr && intAttr.getType().isSignlessInteger(1); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // OpaqueAttr 362 //===----------------------------------------------------------------------===// 363 364 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 365 Identifier dialect, StringRef attrData, 366 Type type) { 367 if (!Dialect::isValidNamespace(dialect.strref())) 368 return emitError() << "invalid dialect namespace '" << dialect << "'"; 369 370 // Check that the dialect is actually registered. 371 MLIRContext *context = dialect.getContext(); 372 if (!context->allowsUnregisteredDialects() && 373 !context->getLoadedDialect(dialect.strref())) { 374 return emitError() 375 << "#" << dialect << "<\"" << attrData << "\"> : " << type 376 << " attribute created with unregistered dialect. If this is " 377 "intended, please call allowUnregisteredDialects() on the " 378 "MLIRContext, or use -allow-unregistered-dialect with " 379 "the MLIR opt tool used"; 380 } 381 382 return success(); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // ElementsAttr 387 //===----------------------------------------------------------------------===// 388 389 ShapedType ElementsAttr::getType() const { 390 return Attribute::getType().cast<ShapedType>(); 391 } 392 393 /// Returns the number of elements held by this attribute. 394 int64_t ElementsAttr::getNumElements() const { 395 return getType().getNumElements(); 396 } 397 398 /// Return the value at the given index. If index does not refer to a valid 399 /// element, then a null attribute is returned. 
400 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 401 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 402 return denseAttr.getValue(index); 403 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 404 return opaqueAttr.getValue(index); 405 return cast<SparseElementsAttr>().getValue(index); 406 } 407 408 /// Return if the given 'index' refers to a valid element in this attribute. 409 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 410 auto type = getType(); 411 412 // Verify that the rank of the indices matches the held type. 413 auto rank = type.getRank(); 414 if (rank == 0 && index.size() == 1 && index[0] == 0) 415 return true; 416 if (rank != static_cast<int64_t>(index.size())) 417 return false; 418 419 // Verify that all of the indices are within the shape dimensions. 420 auto shape = type.getShape(); 421 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 422 int64_t dim = static_cast<int64_t>(index[i]); 423 return 0 <= dim && dim < shape[i]; 424 }); 425 } 426 427 ElementsAttr 428 ElementsAttr::mapValues(Type newElementType, 429 function_ref<APInt(const APInt &)> mapping) const { 430 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 431 return intOrFpAttr.mapValues(newElementType, mapping); 432 llvm_unreachable("unsupported ElementsAttr subtype"); 433 } 434 435 ElementsAttr 436 ElementsAttr::mapValues(Type newElementType, 437 function_ref<APInt(const APFloat &)> mapping) const { 438 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 439 return intOrFpAttr.mapValues(newElementType, mapping); 440 llvm_unreachable("unsupported ElementsAttr subtype"); 441 } 442 443 /// Method for support type inquiry through isa, cast and dyn_cast. 
444 bool ElementsAttr::classof(Attribute attr) { 445 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 446 OpaqueElementsAttr, SparseElementsAttr>(); 447 } 448 449 /// Returns the 1 dimensional flattened row-major index from the given 450 /// multi-dimensional index. 451 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 452 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 453 auto type = getType(); 454 455 // Reduce the provided multidimensional index into a flattended 1D row-major 456 // index. 457 auto rank = type.getRank(); 458 auto shape = type.getShape(); 459 uint64_t valueIndex = 0; 460 uint64_t dimMultiplier = 1; 461 for (int i = rank - 1; i >= 0; --i) { 462 valueIndex += index[i] * dimMultiplier; 463 dimMultiplier *= shape[i]; 464 } 465 return valueIndex; 466 } 467 468 //===----------------------------------------------------------------------===// 469 // DenseElementsAttr Utilities 470 //===----------------------------------------------------------------------===// 471 472 /// Get the bitwidth of a dense element type within the buffer. 473 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 474 static size_t getDenseElementStorageWidth(size_t origWidth) { 475 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 476 } 477 static size_t getDenseElementStorageWidth(Type elementType) { 478 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 479 } 480 481 /// Set a bit to a specific value. 482 static void setBit(char *rawData, size_t bitPos, bool value) { 483 if (value) 484 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 485 else 486 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 487 } 488 489 /// Return the value of the specified bit. 
490 static bool getBit(const char *rawData, size_t bitPos) { 491 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 492 } 493 494 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 495 /// BE format. 496 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 497 char *result) { 498 assert(llvm::support::endian::system_endianness() == // NOLINT 499 llvm::support::endianness::big); // NOLINT 500 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 501 502 // Copy the words filled with data. 503 // For example, when `value` has 2 words, the first word is filled with data. 504 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 505 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 506 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 507 numFilledWords, result); 508 // Convert last word of APInt to LE format and store it in char 509 // array(`valueLE`). 510 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 511 size_t lastWordPos = numFilledWords; 512 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 513 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 514 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 515 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 516 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 517 // and store it in `result`. 518 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 519 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 520 valueLE.begin(), result + lastWordPos, 521 (numBytes - lastWordPos) * CHAR_BIT, 1); 522 } 523 524 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 525 /// format. 
526 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 527 APInt &result) { 528 assert(llvm::support::endian::system_endianness() == // NOLINT 529 llvm::support::endianness::big); // NOLINT 530 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 531 532 // Copy the data that fills the word of `result` from `inArray`. 533 // For example, when `result` has 2 words, the first word will be filled with 534 // data. So, the first 8 bytes are copied from `inArray` here. 535 // `inArray` (10 bytes, BE): |abcdefgh|ij| 536 // ==> `result` (2 words, BE): |abcdefgh|--------| 537 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 538 std::copy_n( 539 inArray, numFilledWords, 540 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 541 542 // Convert array data which will be last word of `result` to LE format, and 543 // store it in char array(`inArrayLE`). 544 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 545 size_t lastWordPos = numFilledWords; 546 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 547 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 548 inArray + lastWordPos, inArrayLE.begin(), 549 (numBytes - lastWordPos) * CHAR_BIT, 1); 550 551 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 552 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 553 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 554 inArrayLE.begin(), 555 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 556 lastWordPos, 557 APInt::APINT_BITS_PER_WORD, 1); 558 } 559 560 /// Writes value to the bit position `bitPos` in array `rawData`. 561 static void writeBits(char *rawData, size_t bitPos, APInt value) { 562 size_t bitWidth = value.getBitWidth(); 563 564 // If the bitwidth is 1 we just toggle the specific bit. 
565 if (bitWidth == 1) 566 return setBit(rawData, bitPos, value.isOneValue()); 567 568 // Otherwise, the bit position is guaranteed to be byte aligned. 569 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 570 if (llvm::support::endian::system_endianness() == 571 llvm::support::endianness::big) { 572 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 573 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 574 // work correctly in BE format. 575 // ex. `value` (2 words including 10 bytes) 576 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 577 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 578 rawData + (bitPos / CHAR_BIT)); 579 } else { 580 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 581 llvm::divideCeil(bitWidth, CHAR_BIT), 582 rawData + (bitPos / CHAR_BIT)); 583 } 584 } 585 586 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 587 /// `rawData`. 588 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 589 // Handle a boolean bit position. 590 if (bitWidth == 1) 591 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 592 593 // Otherwise, the bit position must be 8-bit aligned. 594 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 595 APInt result(bitWidth, 0); 596 if (llvm::support::endian::system_endianness() == 597 llvm::support::endianness::big) { 598 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 599 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 600 // work correctly in BE format. 601 // ex. 
`result` (2 words including 10 bytes) 602 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 603 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 604 llvm::divideCeil(bitWidth, CHAR_BIT), result); 605 } else { 606 std::copy_n(rawData + (bitPos / CHAR_BIT), 607 llvm::divideCeil(bitWidth, CHAR_BIT), 608 const_cast<char *>( 609 reinterpret_cast<const char *>(result.getRawData()))); 610 } 611 return result; 612 } 613 614 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 615 /// the same element count as 'type'. 616 template <typename Values> 617 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 618 return (values.size() == 1) || 619 (type.getNumElements() == static_cast<int64_t>(values.size())); 620 } 621 622 //===----------------------------------------------------------------------===// 623 // DenseElementsAttr Iterators 624 //===----------------------------------------------------------------------===// 625 626 //===----------------------------------------------------------------------===// 627 // AttributeElementIterator 628 629 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 630 DenseElementsAttr attr, size_t index) 631 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 632 Attribute, Attribute, Attribute>( 633 attr.getAsOpaquePointer(), index) {} 634 635 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 636 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 637 Type eltTy = owner.getType().getElementType(); 638 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 639 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 640 if (eltTy.isa<IndexType>()) 641 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 642 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 643 IntElementIterator intIt(owner, index); 644 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), 
intIt); 645 return FloatAttr::get(eltTy, *floatIt); 646 } 647 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { 648 auto complexEltTy = complexTy.getElementType(); 649 ComplexIntElementIterator complexIntIt(owner, index); 650 if (complexEltTy.isa<IntegerType>()) { 651 auto value = *complexIntIt; 652 auto real = IntegerAttr::get(complexEltTy, value.real()); 653 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 654 return ArrayAttr::get(complexTy.getContext(), 655 ArrayRef<Attribute>{real, imag}); 656 } 657 658 ComplexFloatElementIterator complexFloatIt( 659 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); 660 auto value = *complexFloatIt; 661 auto real = FloatAttr::get(complexEltTy, value.real()); 662 auto imag = FloatAttr::get(complexEltTy, value.imag()); 663 return ArrayAttr::get(complexTy.getContext(), 664 ArrayRef<Attribute>{real, imag}); 665 } 666 if (owner.isa<DenseStringElementsAttr>()) { 667 ArrayRef<StringRef> vals = owner.getRawStringData(); 668 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 669 } 670 llvm_unreachable("unexpected element type"); 671 } 672 673 //===----------------------------------------------------------------------===// 674 // BoolElementIterator 675 676 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 677 DenseElementsAttr attr, size_t dataIndex) 678 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 679 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 680 681 bool DenseElementsAttr::BoolElementIterator::operator*() const { 682 return getBit(getData(), getDataIndex()); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // IntElementIterator 687 688 DenseElementsAttr::IntElementIterator::IntElementIterator( 689 DenseElementsAttr attr, size_t dataIndex) 690 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 691 attr.getRawData().data(), attr.isSplat(), dataIndex), 692 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 693 694 APInt DenseElementsAttr::IntElementIterator::operator*() const { 695 return readBits(getData(), 696 getDataIndex() * getDenseElementStorageWidth(bitWidth), 697 bitWidth); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // ComplexIntElementIterator 702 703 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 704 DenseElementsAttr attr, size_t dataIndex) 705 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 706 std::complex<APInt>, std::complex<APInt>, 707 std::complex<APInt>>( 708 attr.getRawData().data(), attr.isSplat(), dataIndex) { 709 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 710 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 711 } 712 713 std::complex<APInt> 714 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 715 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 716 size_t offset = 
getDataIndex() * storageWidth * 2; 717 return {readBits(getData(), offset, bitWidth), 718 readBits(getData(), offset + storageWidth, bitWidth)}; 719 } 720 721 //===----------------------------------------------------------------------===// 722 // FloatElementIterator 723 724 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 725 const llvm::fltSemantics &smt, IntElementIterator it) 726 : llvm::mapped_iterator<IntElementIterator, 727 std::function<APFloat(const APInt &)>>( 728 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 729 730 //===----------------------------------------------------------------------===// 731 // ComplexFloatElementIterator 732 733 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 734 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 735 : llvm::mapped_iterator< 736 ComplexIntElementIterator, 737 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 738 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 739 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 740 }) {} 741 742 //===----------------------------------------------------------------------===// 743 // DenseElementsAttr 744 //===----------------------------------------------------------------------===// 745 746 /// Method for support type inquiry through isa, cast and dyn_cast. 747 bool DenseElementsAttr::classof(Attribute attr) { 748 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 749 } 750 751 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 752 ArrayRef<Attribute> values) { 753 assert(hasSameElementsOrSplat(type, values)); 754 755 // If the element type is not based on int/float/index, assume it is a string 756 // type. 
757 auto eltType = type.getElementType(); 758 if (!type.getElementType().isIntOrIndexOrFloat()) { 759 SmallVector<StringRef, 8> stringValues; 760 stringValues.reserve(values.size()); 761 for (Attribute attr : values) { 762 assert(attr.isa<StringAttr>() && 763 "expected string value for non integer/index/float element"); 764 stringValues.push_back(attr.cast<StringAttr>().getValue()); 765 } 766 return get(type, stringValues); 767 } 768 769 // Otherwise, get the raw storage width to use for the allocation. 770 size_t bitWidth = getDenseElementBitWidth(eltType); 771 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 772 773 // Compress the attribute values into a character buffer. 774 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 775 values.size()); 776 APInt intVal; 777 for (unsigned i = 0, e = values.size(); i < e; ++i) { 778 assert(eltType == values[i].getType() && 779 "expected attribute value to have element type"); 780 if (eltType.isa<FloatType>()) 781 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 782 else if (eltType.isa<IntegerType, IndexType>()) 783 intVal = values[i].cast<IntegerAttr>().getValue(); 784 else 785 llvm_unreachable("unexpected element type"); 786 787 assert(intVal.getBitWidth() == bitWidth && 788 "expected value to have same bitwidth as element type"); 789 writeBits(data.data(), i * storageBitWidth, intVal); 790 } 791 return DenseIntOrFPElementsAttr::getRaw(type, data, 792 /*isSplat=*/(values.size() == 1)); 793 } 794 795 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 796 ArrayRef<bool> values) { 797 assert(hasSameElementsOrSplat(type, values)); 798 assert(type.getElementType().isInteger(1)); 799 800 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 801 for (int i = 0, e = values.size(); i != e; ++i) 802 setBit(buff.data(), i, values[i]); 803 return DenseIntOrFPElementsAttr::getRaw(type, buff, 804 /*isSplat=*/(values.size() == 1)); 805 } 806 807 
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<StringRef> values) {
  // String payloads are held by a dedicated storage class; the element type
  // must therefore not be an integer or float type.
  assert(!type.getElementType().isIntOrFloat());
  return DenseStringElementsAttr::get(type, values);
}

/// Builds a dense integer (or index) attribute from APInt values. Each value
/// is expected to have the bit width of the element type of `type`; passing
/// exactly one value produces a splat.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<APInt> values) {
  Type eltTy = type.getElementType();
  assert(eltTy.isIntOrIndex());
  assert(hasSameElementsOrSplat(type, values));
  bool splat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(
      type, getDenseElementStorageWidth(eltTy), values, splat);
}

/// Builds a dense complex-integer attribute. Each std::complex<APInt> is handed
/// to the raw writer as two consecutive APInts (real, then imaginary), each
/// stored in half of the complex element's storage width.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<std::complex<APInt>> values) {
  auto complexTy = type.getElementType().cast<ComplexType>();
  assert(complexTy.getElementType().isa<IntegerType>());
  assert(hasSameElementsOrSplat(type, values));
  size_t partStorageWidth = getDenseElementStorageWidth(complexTy) / 2;
  // NOTE(review): this reinterpret_cast assumes std::complex<APInt> is laid out
  // as two adjacent APInt parts. The standard only specifies std::complex for
  // float, double and long double.
  const auto *parts = reinterpret_cast<const APInt *>(values.data());
  ArrayRef<APInt> flatParts(parts, values.size() * 2);
  bool splat = values.size() == 1;
  return DenseIntOrFPElementsAttr::getRaw(type, partStorageWidth, flatParts,
                                          splat);
}

// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
839 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 840 ArrayRef<APFloat> values) { 841 assert(type.getElementType().isa<FloatType>()); 842 assert(hasSameElementsOrSplat(type, values)); 843 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 844 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 845 /*isSplat=*/(values.size() == 1)); 846 } 847 DenseElementsAttr 848 DenseElementsAttr::get(ShapedType type, 849 ArrayRef<std::complex<APFloat>> values) { 850 ComplexType complex = type.getElementType().cast<ComplexType>(); 851 assert(complex.getElementType().isa<FloatType>()); 852 assert(hasSameElementsOrSplat(type, values)); 853 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 854 values.size() * 2); 855 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 856 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 857 /*isSplat=*/(values.size() == 1)); 858 } 859 860 /// Construct a dense elements attribute from a raw buffer representing the 861 /// data for this attribute. Users should generally not use this methods as 862 /// the expected buffer format may not be a form the user expects. 863 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 864 ArrayRef<char> rawBuffer, 865 bool isSplatBuffer) { 866 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 867 } 868 869 /// Returns true if the given buffer is a valid raw buffer for the given type. 870 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 871 ArrayRef<char> rawBuffer, 872 bool &detectedSplat) { 873 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 874 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 875 876 // Storage width of 1 is special as it is packed by the bit. 877 if (storageWidth == 1) { 878 // Check for a splat, or a buffer equal to the number of elements. 
879 if ((detectedSplat = rawBuffer.size() == 1)) 880 return true; 881 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 882 } 883 // All other types are 8-bit aligned. 884 if ((detectedSplat = rawBufferWidth == storageWidth)) 885 return true; 886 return rawBufferWidth == (storageWidth * type.getNumElements()); 887 } 888 889 /// Check the information for a C++ data type, check if this type is valid for 890 /// the current attribute. This method is used to verify specific type 891 /// invariants that the templatized 'getValues' method cannot. 892 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 893 bool isSigned) { 894 // Make sure that the data element size is the same as the type element width. 895 if (getDenseElementBitWidth(type) != 896 static_cast<size_t>(dataEltSize * CHAR_BIT)) 897 return false; 898 899 // Check that the element type is either float or integer or index. 900 if (!isInt) 901 return type.isa<FloatType>(); 902 if (type.isIndex()) 903 return true; 904 905 auto intType = type.dyn_cast<IntegerType>(); 906 if (!intType) 907 return false; 908 909 // Make sure signedness semantics is consistent. 910 if (intType.isSignless()) 911 return true; 912 return intType.isSigned() ? isSigned : !isSigned; 913 } 914 915 /// Defaults down the subclass implementation. 916 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 917 ArrayRef<char> data, 918 int64_t dataEltSize, 919 bool isInt, bool isSigned) { 920 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 921 isSigned); 922 } 923 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 924 ArrayRef<char> data, 925 int64_t dataEltSize, 926 bool isInt, 927 bool isSigned) { 928 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 929 isInt, isSigned); 930 } 931 932 /// A method used to verify specific type invariants that the templatized 'get' 933 /// method cannot. 
934 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 935 bool isSigned) const { 936 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 937 isSigned); 938 } 939 940 /// Check the information for a C++ data type, check if this type is valid for 941 /// the current attribute. 942 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 943 bool isSigned) const { 944 return ::isValidIntOrFloat( 945 getType().getElementType().cast<ComplexType>().getElementType(), 946 dataEltSize / 2, isInt, isSigned); 947 } 948 949 /// Returns true if this attribute corresponds to a splat, i.e. if all element 950 /// values are the same. 951 bool DenseElementsAttr::isSplat() const { 952 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 953 } 954 955 /// Return the held element values as a range of Attributes. 956 auto DenseElementsAttr::getAttributeValues() const 957 -> llvm::iterator_range<AttributeElementIterator> { 958 return {attr_value_begin(), attr_value_end()}; 959 } 960 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 961 return AttributeElementIterator(*this, 0); 962 } 963 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 964 return AttributeElementIterator(*this, getNumElements()); 965 } 966 967 /// Return the held element values as a range of bool. The element type of 968 /// this attribute must be of integer type of bitwidth 1. 969 auto DenseElementsAttr::getBoolValues() const 970 -> llvm::iterator_range<BoolElementIterator> { 971 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 972 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 973 (void)eltType; 974 return {BoolElementIterator(*this, 0), 975 BoolElementIterator(*this, getNumElements())}; 976 } 977 978 /// Return the held element values as a range of APInts. The element type of 979 /// this attribute must be of integer type. 
980 auto DenseElementsAttr::getIntValues() const 981 -> llvm::iterator_range<IntElementIterator> { 982 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 983 return {raw_int_begin(), raw_int_end()}; 984 } 985 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 986 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 987 return raw_int_begin(); 988 } 989 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 990 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 991 return raw_int_end(); 992 } 993 auto DenseElementsAttr::getComplexIntValues() const 994 -> llvm::iterator_range<ComplexIntElementIterator> { 995 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 996 (void)eltTy; 997 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 998 return {ComplexIntElementIterator(*this, 0), 999 ComplexIntElementIterator(*this, getNumElements())}; 1000 } 1001 1002 /// Return the held element values as a range of APFloat. The element type of 1003 /// this attribute must be of float type. 
1004 auto DenseElementsAttr::getFloatValues() const 1005 -> llvm::iterator_range<FloatElementIterator> { 1006 auto elementType = getType().getElementType().cast<FloatType>(); 1007 const auto &elementSemantics = elementType.getFloatSemantics(); 1008 return {FloatElementIterator(elementSemantics, raw_int_begin()), 1009 FloatElementIterator(elementSemantics, raw_int_end())}; 1010 } 1011 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1012 return getFloatValues().begin(); 1013 } 1014 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1015 return getFloatValues().end(); 1016 } 1017 auto DenseElementsAttr::getComplexFloatValues() const 1018 -> llvm::iterator_range<ComplexFloatElementIterator> { 1019 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1020 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1021 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1022 return {{semantics, {*this, 0}}, 1023 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1024 } 1025 1026 /// Return the raw storage data held by this attribute. 1027 ArrayRef<char> DenseElementsAttr::getRawData() const { 1028 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1029 } 1030 1031 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1032 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1033 } 1034 1035 /// Return a new DenseElementsAttr that has the same data as the current 1036 /// attribute, but has been reshaped to 'newType'. The new type must have the 1037 /// same total number of elements as well as element type. 
1038 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1039 ShapedType curType = getType(); 1040 if (curType == newType) 1041 return *this; 1042 1043 assert(newType.getElementType() == curType.getElementType() && 1044 "expected the same element type"); 1045 assert(newType.getNumElements() == curType.getNumElements() && 1046 "expected the same number of elements"); 1047 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1048 } 1049 1050 /// Return a new DenseElementsAttr that has the same data as the current 1051 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1052 /// type must have the same shape and element types of the same bitwidth as the 1053 /// current type. 1054 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1055 ShapedType curType = getType(); 1056 Type curElType = curType.getElementType(); 1057 if (curElType == newElType) 1058 return *this; 1059 1060 assert(getDenseElementBitWidth(newElType) == 1061 getDenseElementBitWidth(curElType) && 1062 "expected element types with the same bitwidth"); 1063 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1064 getRawData(), isSplat()); 1065 } 1066 1067 DenseElementsAttr 1068 DenseElementsAttr::mapValues(Type newElementType, 1069 function_ref<APInt(const APInt &)> mapping) const { 1070 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1071 } 1072 1073 DenseElementsAttr DenseElementsAttr::mapValues( 1074 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1075 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1076 } 1077 1078 //===----------------------------------------------------------------------===// 1079 // DenseIntOrFPElementsAttr 1080 //===----------------------------------------------------------------------===// 1081 1082 /// Utility method to write a range of APInt values to a buffer. 
1083 template <typename APRangeT> 1084 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1085 APRangeT &&values) { 1086 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1087 size_t offset = 0; 1088 for (auto it = values.begin(), e = values.end(); it != e; 1089 ++it, offset += storageWidth) { 1090 assert((*it).getBitWidth() <= storageWidth); 1091 writeBits(data.data(), offset, *it); 1092 } 1093 } 1094 1095 /// Constructs a dense elements attribute from an array of raw APFloat values. 1096 /// Each APFloat value is expected to have the same bitwidth as the element 1097 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1098 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1099 size_t storageWidth, 1100 ArrayRef<APFloat> values, 1101 bool isSplat) { 1102 std::vector<char> data; 1103 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1104 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1105 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1106 } 1107 1108 /// Constructs a dense elements attribute from an array of raw APInt values. 1109 /// Each APInt value is expected to have the same bitwidth as the element type 1110 /// of 'type'. 
1111 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1112 size_t storageWidth, 1113 ArrayRef<APInt> values, 1114 bool isSplat) { 1115 std::vector<char> data; 1116 writeAPIntsToBuffer(storageWidth, data, values); 1117 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1118 } 1119 1120 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1121 ArrayRef<char> data, 1122 bool isSplat) { 1123 assert((type.isa<RankedTensorType, VectorType>()) && 1124 "type must be ranked tensor or vector"); 1125 assert(type.hasStaticShape() && "type must have static shape"); 1126 return Base::get(type.getContext(), type, data, isSplat); 1127 } 1128 1129 /// Overload of the raw 'get' method that asserts that the given type is of 1130 /// complex type. This method is used to verify type invariants that the 1131 /// templatized 'get' method cannot. 1132 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1133 ArrayRef<char> data, 1134 int64_t dataEltSize, 1135 bool isInt, 1136 bool isSigned) { 1137 assert(::isValidIntOrFloat( 1138 type.getElementType().cast<ComplexType>().getElementType(), 1139 dataEltSize / 2, isInt, isSigned)); 1140 1141 int64_t numElements = data.size() / dataEltSize; 1142 assert(numElements == 1 || numElements == type.getNumElements()); 1143 return getRaw(type, data, /*isSplat=*/numElements == 1); 1144 } 1145 1146 /// Overload of the 'getRaw' method that asserts that the given type is of 1147 /// integer type. This method is used to verify type invariants that the 1148 /// templatized 'get' method cannot. 
1149 DenseElementsAttr 1150 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1151 int64_t dataEltSize, bool isInt, 1152 bool isSigned) { 1153 assert( 1154 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1155 1156 int64_t numElements = data.size() / dataEltSize; 1157 assert(numElements == 1 || numElements == type.getNumElements()); 1158 return getRaw(type, data, /*isSplat=*/numElements == 1); 1159 } 1160 1161 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1162 const char *inRawData, char *outRawData, size_t elementBitWidth, 1163 size_t numElements) { 1164 using llvm::support::ulittle16_t; 1165 using llvm::support::ulittle32_t; 1166 using llvm::support::ulittle64_t; 1167 1168 assert(llvm::support::endian::system_endianness() == // NOLINT 1169 llvm::support::endianness::big); // NOLINT 1170 // NOLINT to avoid warning message about replacing by static_assert() 1171 1172 // Following std::copy_n always converts endianness on BE machine. 
1173 switch (elementBitWidth) { 1174 case 16: { 1175 const ulittle16_t *inRawDataPos = 1176 reinterpret_cast<const ulittle16_t *>(inRawData); 1177 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1178 std::copy_n(inRawDataPos, numElements, outDataPos); 1179 break; 1180 } 1181 case 32: { 1182 const ulittle32_t *inRawDataPos = 1183 reinterpret_cast<const ulittle32_t *>(inRawData); 1184 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1185 std::copy_n(inRawDataPos, numElements, outDataPos); 1186 break; 1187 } 1188 case 64: { 1189 const ulittle64_t *inRawDataPos = 1190 reinterpret_cast<const ulittle64_t *>(inRawData); 1191 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1192 std::copy_n(inRawDataPos, numElements, outDataPos); 1193 break; 1194 } 1195 default: { 1196 size_t nBytes = elementBitWidth / CHAR_BIT; 1197 for (size_t i = 0; i < nBytes; i++) 1198 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1199 break; 1200 } 1201 } 1202 } 1203 1204 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1205 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1206 ShapedType type) { 1207 size_t numElements = type.getNumElements(); 1208 Type elementType = type.getElementType(); 1209 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1210 elementType = complexTy.getElementType(); 1211 numElements = numElements * 2; 1212 } 1213 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1214 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1215 inRawData.size() <= outRawData.size()); 1216 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1217 elementBitWidth, numElements); 1218 } 1219 1220 //===----------------------------------------------------------------------===// 1221 // DenseFPElementsAttr 1222 //===----------------------------------------------------------------------===// 1223 1224 template <typename Fn, typename Attr> 
1225 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1226 Type newElementType, 1227 llvm::SmallVectorImpl<char> &data) { 1228 size_t bitWidth = getDenseElementBitWidth(newElementType); 1229 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1230 1231 ShapedType newArrayType; 1232 if (inType.isa<RankedTensorType>()) 1233 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1234 else if (inType.isa<UnrankedTensorType>()) 1235 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1236 else if (inType.isa<VectorType>()) 1237 newArrayType = VectorType::get(inType.getShape(), newElementType); 1238 else 1239 assert(newArrayType && "Unhandled tensor type"); 1240 1241 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1242 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1243 1244 // Functor used to process a single element value of the attribute. 1245 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1246 auto newInt = mapping(value); 1247 assert(newInt.getBitWidth() == bitWidth); 1248 writeBits(data.data(), index * storageBitWidth, newInt); 1249 }; 1250 1251 // Check for the splat case. 1252 if (attr.isSplat()) { 1253 processElt(*attr.begin(), /*index=*/0); 1254 return newArrayType; 1255 } 1256 1257 // Otherwise, process all of the element values. 1258 uint64_t elementIdx = 0; 1259 for (auto value : attr) 1260 processElt(value, elementIdx++); 1261 return newArrayType; 1262 } 1263 1264 DenseElementsAttr DenseFPElementsAttr::mapValues( 1265 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1266 llvm::SmallVector<char, 8> elementData; 1267 auto newArrayType = 1268 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1269 1270 return getRaw(newArrayType, elementData, isSplat()); 1271 } 1272 1273 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1274 bool DenseFPElementsAttr::classof(Attribute attr) { 1275 return attr.isa<DenseElementsAttr>() && 1276 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1277 } 1278 1279 //===----------------------------------------------------------------------===// 1280 // DenseIntElementsAttr 1281 //===----------------------------------------------------------------------===// 1282 1283 DenseElementsAttr DenseIntElementsAttr::mapValues( 1284 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1285 llvm::SmallVector<char, 8> elementData; 1286 auto newArrayType = 1287 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1288 1289 return getRaw(newArrayType, elementData, isSplat()); 1290 } 1291 1292 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1293 bool DenseIntElementsAttr::classof(Attribute attr) { 1294 return attr.isa<DenseElementsAttr>() && 1295 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1296 } 1297 1298 //===----------------------------------------------------------------------===// 1299 // OpaqueElementsAttr 1300 //===----------------------------------------------------------------------===// 1301 1302 /// Return the value at the given index. If index does not refer to a valid 1303 /// element, then a null attribute is returned. 
1304 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1305 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1306 return Attribute(); 1307 } 1308 1309 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1310 Dialect *dialect = getDialect().getDialect(); 1311 if (!dialect) 1312 return true; 1313 auto *interface = 1314 dialect->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1315 if (!interface) 1316 return true; 1317 return failed(interface->decode(*this, result)); 1318 } 1319 1320 LogicalResult 1321 OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1322 Identifier dialect, StringRef value, 1323 ShapedType type) { 1324 if (!Dialect::isValidNamespace(dialect.strref())) 1325 return emitError() << "invalid dialect namespace '" << dialect << "'"; 1326 return success(); 1327 } 1328 1329 //===----------------------------------------------------------------------===// 1330 // SparseElementsAttr 1331 //===----------------------------------------------------------------------===// 1332 1333 /// Return the value of the element at the given index. 1334 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1335 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1336 auto type = getType(); 1337 1338 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1339 // as a 1-D index array. 1340 auto sparseIndices = getIndices(); 1341 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1342 1343 // Check to see if the indices are a splat. 1344 if (sparseIndices.isSplat()) { 1345 // If the index is also not a splat of the index value, we know that the 1346 // value is zero. 1347 auto splatIndex = *sparseIndexValues.begin(); 1348 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1349 return getZeroAttr(); 1350 1351 // If the indices are a splat, we also expect the values to be a splat. 
1352 assert(getValues().isSplat() && "expected splat values"); 1353 return getValues().getSplatValue(); 1354 } 1355 1356 // Build a mapping between known indices and the offset of the stored element. 1357 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1358 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1359 size_t rank = type.getRank(); 1360 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1361 mappedIndices.try_emplace( 1362 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1363 1364 // Look for the provided index key within the mapped indices. If the provided 1365 // index is not found, then return a zero attribute. 1366 auto it = mappedIndices.find(index); 1367 if (it == mappedIndices.end()) 1368 return getZeroAttr(); 1369 1370 // Otherwise, return the held sparse value element. 1371 return getValues().getValue(it->second); 1372 } 1373 1374 /// Get a zero APFloat for the given sparse attribute. 1375 APFloat SparseElementsAttr::getZeroAPFloat() const { 1376 auto eltType = getType().getElementType().cast<FloatType>(); 1377 return APFloat(eltType.getFloatSemantics()); 1378 } 1379 1380 /// Get a zero APInt for the given sparse attribute. 1381 APInt SparseElementsAttr::getZeroAPInt() const { 1382 auto eltType = getType().getElementType().cast<IntegerType>(); 1383 return APInt::getNullValue(eltType.getWidth()); 1384 } 1385 1386 /// Get a zero attribute for the given attribute type. 1387 Attribute SparseElementsAttr::getZeroAttr() const { 1388 auto eltType = getType().getElementType(); 1389 1390 // Handle floating point elements. 1391 if (eltType.isa<FloatType>()) 1392 return FloatAttr::get(eltType, 0); 1393 1394 // Otherwise, this is an integer. 1395 // TODO: Handle StringAttr here. 1396 return IntegerAttr::get(eltType, 0); 1397 } 1398 1399 /// Flatten, and return, all of the sparse indices in this attribute in 1400 /// row-major order. 
1401 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1402 std::vector<ptrdiff_t> flatSparseIndices; 1403 1404 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1405 // as a 1-D index array. 1406 auto sparseIndices = getIndices(); 1407 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1408 if (sparseIndices.isSplat()) { 1409 SmallVector<uint64_t, 8> indices(getType().getRank(), 1410 *sparseIndexValues.begin()); 1411 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1412 return flatSparseIndices; 1413 } 1414 1415 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1416 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1417 size_t rank = getType().getRank(); 1418 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1419 flatSparseIndices.push_back(getFlattenedIndex( 1420 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1421 return flatSparseIndices; 1422 } 1423 1424 //===----------------------------------------------------------------------===// 1425 // TypeAttr 1426 //===----------------------------------------------------------------------===// 1427 1428 void TypeAttr::walkImmediateSubElements( 1429 function_ref<void(Attribute)> walkAttrsFn, 1430 function_ref<void(Type)> walkTypesFn) const { 1431 walkTypesFn(getValue()); 1432 } 1433