1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AffineMapAttr 26 //===----------------------------------------------------------------------===// 27 28 AffineMapAttr AffineMapAttr::get(AffineMap value) { 29 return Base::get(value.getContext(), value); 30 } 31 32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 33 34 //===----------------------------------------------------------------------===// 35 // ArrayAttr 36 //===----------------------------------------------------------------------===// 37 38 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 39 return Base::get(context, value); 40 } 41 42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 43 44 Attribute ArrayAttr::operator[](unsigned idx) const { 45 assert(idx < size() && "index out of bounds"); 46 return getValue()[idx]; 47 } 48 49 //===----------------------------------------------------------------------===// 50 // DictionaryAttr 51 //===----------------------------------------------------------------------===// 52 53 /// Helper function that does either an in place sort or sorts from source array 54 /// into destination. If inPlace then storage is both the source and the 55 /// destination, else value is the source and storage destination. Returns 56 /// whether source was sorted. 57 template <bool inPlace> 58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 59 SmallVectorImpl<NamedAttribute> &storage) { 60 // Specialize for the common case. 61 switch (value.size()) { 62 case 0: 63 // Zero already sorted. 64 break; 65 case 1: 66 // One already sorted but may need to be copied. 67 if (!inPlace) 68 storage.assign({value[0]}); 69 break; 70 case 2: { 71 bool isSorted = value[0] < value[1]; 72 if (inPlace) { 73 if (!isSorted) 74 std::swap(storage[0], storage[1]); 75 } else if (isSorted) { 76 storage.assign({value[0], value[1]}); 77 } else { 78 storage.assign({value[1], value[0]}); 79 } 80 return !isSorted; 81 } 82 default: 83 if (!inPlace) 84 storage.assign(value.begin(), value.end()); 85 // Check to see they are sorted already. 86 bool isSorted = llvm::is_sorted(value); 87 if (!isSorted) { 88 // If not, do a general sort. 89 llvm::array_pod_sort(storage.begin(), storage.end()); 90 value = storage; 91 } 92 return !isSorted; 93 } 94 return false; 95 } 96 97 /// Returns an entry with a duplicate name from the given sorted array of named 98 /// attributes. Returns llvm::None if all elements have unique names. 99 static Optional<NamedAttribute> 100 findDuplicateElement(ArrayRef<NamedAttribute> value) { 101 const Optional<NamedAttribute> none{llvm::None}; 102 if (value.size() < 2) 103 return none; 104 105 if (value.size() == 2) 106 return value[0].first == value[1].first ? value[0] : none; 107 108 auto it = std::adjacent_find( 109 value.begin(), value.end(), 110 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 111 return it != value.end() ? *it : none; 112 } 113 114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 115 SmallVectorImpl<NamedAttribute> &storage) { 116 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 117 assert(!findDuplicateElement(storage) && 118 "DictionaryAttr element names must be unique"); 119 return isSorted; 120 } 121 122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 123 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 124 assert(!findDuplicateElement(array) && 125 "DictionaryAttr element names must be unique"); 126 return isSorted; 127 } 128 129 Optional<NamedAttribute> 130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 131 bool isSorted) { 132 if (!isSorted) 133 dictionaryAttrSort</*inPlace=*/true>(array, array); 134 return findDuplicateElement(array); 135 } 136 137 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 138 MLIRContext *context) { 139 if (value.empty()) 140 return DictionaryAttr::getEmpty(context); 141 assert(llvm::all_of(value, 142 [](const NamedAttribute &attr) { return attr.second; }) && 143 "value cannot have null entries"); 144 145 // We need to sort the element list to canonicalize it. 146 SmallVector<NamedAttribute, 8> storage; 147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 148 value = storage; 149 assert(!findDuplicateElement(value) && 150 "DictionaryAttr element names must be unique"); 151 return Base::get(context, value); 152 } 153 /// Construct a dictionary with an array of values that is known to already be 154 /// sorted by name and uniqued. 155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 156 MLIRContext *context) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 // Ensure that the attribute elements are unique and sorted. 160 assert(llvm::is_sorted(value, 161 [](NamedAttribute l, NamedAttribute r) { 162 return l.first.strref() < r.first.strref(); 163 }) && 164 "expected attribute values to be sorted"); 165 assert(!findDuplicateElement(value) && 166 "DictionaryAttr element names must be unique"); 167 return Base::get(context, value); 168 } 169 170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 171 return getImpl()->getElements(); 172 } 173 174 /// Return the specified attribute if present, null otherwise. 175 Attribute DictionaryAttr::get(StringRef name) const { 176 Optional<NamedAttribute> attr = getNamed(name); 177 return attr ? attr->second : nullptr; 178 } 179 Attribute DictionaryAttr::get(Identifier name) const { 180 Optional<NamedAttribute> attr = getNamed(name); 181 return attr ? attr->second : nullptr; 182 } 183 184 /// Return the specified named attribute if present, None otherwise. 185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 186 ArrayRef<NamedAttribute> values = getValue(); 187 const auto *it = llvm::lower_bound(values, name); 188 return it != values.end() && it->first == name ? *it 189 : Optional<NamedAttribute>(); 190 } 191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 192 for (auto elt : getValue()) 193 if (elt.first == name) 194 return elt; 195 return llvm::None; 196 } 197 198 DictionaryAttr::iterator DictionaryAttr::begin() const { 199 return getValue().begin(); 200 } 201 DictionaryAttr::iterator DictionaryAttr::end() const { 202 return getValue().end(); 203 } 204 size_t DictionaryAttr::size() const { return getValue().size(); } 205 206 //===----------------------------------------------------------------------===// 207 // FloatAttr 208 //===----------------------------------------------------------------------===// 209 210 FloatAttr FloatAttr::get(Type type, double value) { 211 return Base::get(type.getContext(), type, value); 212 } 213 214 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 215 return Base::getChecked(loc, type, value); 216 } 217 218 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 219 return Base::get(type.getContext(), type, value); 220 } 221 222 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 223 return Base::getChecked(loc, type, value); 224 } 225 226 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 227 228 double FloatAttr::getValueAsDouble() const { 229 return getValueAsDouble(getValue()); 230 } 231 double FloatAttr::getValueAsDouble(APFloat value) { 232 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 233 bool losesInfo = false; 234 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 235 &losesInfo); 236 } 237 return value.convertToDouble(); 238 } 239 240 /// Verify construction invariants. 241 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 242 if (!type.isa<FloatType>()) 243 return emitError(loc, "expected floating point type"); 244 return success(); 245 } 246 247 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 248 double value) { 249 return verifyFloatTypeInvariants(loc, type); 250 } 251 252 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 253 const APFloat &value) { 254 // Verify that the type is correct. 255 if (failed(verifyFloatTypeInvariants(loc, type))) 256 return failure(); 257 258 // Verify that the type semantics match that of the value. 259 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 260 return emitError( 261 loc, "FloatAttr type doesn't match the type implied by its value"); 262 } 263 return success(); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // SymbolRefAttr 268 //===----------------------------------------------------------------------===// 269 270 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 271 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 272 } 273 274 SymbolRefAttr SymbolRefAttr::get(StringRef value, 275 ArrayRef<FlatSymbolRefAttr> nestedReferences, 276 MLIRContext *ctx) { 277 return Base::get(ctx, value, nestedReferences); 278 } 279 280 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 281 282 StringRef SymbolRefAttr::getLeafReference() const { 283 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 284 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 285 } 286 287 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 288 return getImpl()->getNestedRefs(); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // IntegerAttr 293 //===----------------------------------------------------------------------===// 294 295 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 296 if (type.isSignlessInteger(1)) 297 return BoolAttr::get(value.getBoolValue(), type.getContext()); 298 return Base::get(type.getContext(), type, value); 299 } 300 301 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 302 // This uses 64 bit APInts by default for index type. 303 if (type.isIndex()) 304 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 305 306 auto intType = type.cast<IntegerType>(); 307 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 308 } 309 310 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 311 312 int64_t IntegerAttr::getInt() const { 313 assert((getImpl()->getType().isIndex() || 314 getImpl()->getType().isSignlessInteger()) && 315 "must be signless integer"); 316 return getValue().getSExtValue(); 317 } 318 319 int64_t IntegerAttr::getSInt() const { 320 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 321 return getValue().getSExtValue(); 322 } 323 324 uint64_t IntegerAttr::getUInt() const { 325 assert(getImpl()->getType().isUnsignedInteger() && 326 "must be unsigned integer"); 327 return getValue().getZExtValue(); 328 } 329 330 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 331 if (type.isa<IntegerType, IndexType>()) 332 return success(); 333 return emitError(loc, "expected integer or index type"); 334 } 335 336 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 337 int64_t value) { 338 return verifyIntegerTypeInvariants(loc, type); 339 } 340 341 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 342 const APInt &value) { 343 if (failed(verifyIntegerTypeInvariants(loc, type))) 344 return failure(); 345 if (auto integerType = type.dyn_cast<IntegerType>()) 346 if (integerType.getWidth() != value.getBitWidth()) 347 return emitError(loc, "integer type bit width (") 348 << integerType.getWidth() << ") doesn't match value bit width (" 349 << value.getBitWidth() << ")"; 350 return success(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // BoolAttr 355 356 bool BoolAttr::getValue() const { 357 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 358 return storage->getValue().getBoolValue(); 359 } 360 361 bool BoolAttr::classof(Attribute attr) { 362 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 363 return intAttr && intAttr.getType().isSignlessInteger(1); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // IntegerSetAttr 368 //===----------------------------------------------------------------------===// 369 370 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 371 return Base::get(value.getConstraint(0).getContext(), value); 372 } 373 374 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 375 376 //===----------------------------------------------------------------------===// 377 // OpaqueAttr 378 //===----------------------------------------------------------------------===// 379 380 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 381 MLIRContext *context) { 382 return Base::get(context, dialect, attrData, type); 383 } 384 385 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 386 Type type, Location location) { 387 return Base::getChecked(location, dialect, attrData, type); 388 } 389 390 /// Returns the dialect namespace of the opaque attribute. 391 Identifier OpaqueAttr::getDialectNamespace() const { 392 return getImpl()->dialectNamespace; 393 } 394 395 /// Returns the raw attribute data of the opaque attribute. 396 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 397 398 /// Verify the construction of an opaque attribute. 399 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 400 Identifier dialect, 401 StringRef attrData, 402 Type type) { 403 if (!Dialect::isValidNamespace(dialect.strref())) 404 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 405 return success(); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // StringAttr 410 //===----------------------------------------------------------------------===// 411 412 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 413 return get(bytes, NoneType::get(context)); 414 } 415 416 /// Get an instance of a StringAttr with the given string and Type. 417 StringAttr StringAttr::get(StringRef bytes, Type type) { 418 return Base::get(type.getContext(), bytes, type); 419 } 420 421 StringRef StringAttr::getValue() const { return getImpl()->value; } 422 423 //===----------------------------------------------------------------------===// 424 // TypeAttr 425 //===----------------------------------------------------------------------===// 426 427 TypeAttr TypeAttr::get(Type value) { 428 return Base::get(value.getContext(), value); 429 } 430 431 Type TypeAttr::getValue() const { return getImpl()->value; } 432 433 //===----------------------------------------------------------------------===// 434 // ElementsAttr 435 //===----------------------------------------------------------------------===// 436 437 ShapedType ElementsAttr::getType() const { 438 return Attribute::getType().cast<ShapedType>(); 439 } 440 441 /// Returns the number of elements held by this attribute. 442 int64_t ElementsAttr::getNumElements() const { 443 return getType().getNumElements(); 444 } 445 446 /// Return the value at the given index. If index does not refer to a valid 447 /// element, then a null attribute is returned. 448 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 449 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 450 return denseAttr.getValue(index); 451 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 452 return opaqueAttr.getValue(index); 453 return cast<SparseElementsAttr>().getValue(index); 454 } 455 456 /// Return if the given 'index' refers to a valid element in this attribute. 457 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 458 auto type = getType(); 459 460 // Verify that the rank of the indices matches the held type. 461 auto rank = type.getRank(); 462 if (rank != static_cast<int64_t>(index.size())) 463 return false; 464 465 // Verify that all of the indices are within the shape dimensions. 466 auto shape = type.getShape(); 467 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 468 return static_cast<int64_t>(index[i]) < shape[i]; 469 }); 470 } 471 472 ElementsAttr 473 ElementsAttr::mapValues(Type newElementType, 474 function_ref<APInt(const APInt &)> mapping) const { 475 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 476 return intOrFpAttr.mapValues(newElementType, mapping); 477 llvm_unreachable("unsupported ElementsAttr subtype"); 478 } 479 480 ElementsAttr 481 ElementsAttr::mapValues(Type newElementType, 482 function_ref<APInt(const APFloat &)> mapping) const { 483 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 484 return intOrFpAttr.mapValues(newElementType, mapping); 485 llvm_unreachable("unsupported ElementsAttr subtype"); 486 } 487 488 /// Method for support type inquiry through isa, cast and dyn_cast. 489 bool ElementsAttr::classof(Attribute attr) { 490 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 491 OpaqueElementsAttr, SparseElementsAttr>(); 492 } 493 494 /// Returns the 1 dimensional flattened row-major index from the given 495 /// multi-dimensional index. 496 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 497 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 498 auto type = getType(); 499 500 // Reduce the provided multidimensional index into a flattended 1D row-major 501 // index. 502 auto rank = type.getRank(); 503 auto shape = type.getShape(); 504 uint64_t valueIndex = 0; 505 uint64_t dimMultiplier = 1; 506 for (int i = rank - 1; i >= 0; --i) { 507 valueIndex += index[i] * dimMultiplier; 508 dimMultiplier *= shape[i]; 509 } 510 return valueIndex; 511 } 512 513 //===----------------------------------------------------------------------===// 514 // DenseElementsAttr Utilities 515 //===----------------------------------------------------------------------===// 516 517 /// Get the bitwidth of a dense element type within the buffer. 518 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 519 static size_t getDenseElementStorageWidth(size_t origWidth) { 520 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 521 } 522 static size_t getDenseElementStorageWidth(Type elementType) { 523 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 524 } 525 526 /// Set a bit to a specific value. 527 static void setBit(char *rawData, size_t bitPos, bool value) { 528 if (value) 529 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 530 else 531 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 532 } 533 534 /// Return the value of the specified bit. 535 static bool getBit(const char *rawData, size_t bitPos) { 536 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 537 } 538 539 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 540 /// BE format. 541 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 542 char *result) { 543 assert(llvm::support::endian::system_endianness() == // NOLINT 544 llvm::support::endianness::big); // NOLINT 545 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 546 547 // Copy the words filled with data. 548 // For example, when `value` has 2 words, the first word is filled with data. 549 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 550 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 551 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 552 numFilledWords, result); 553 // Convert last word of APInt to LE format and store it in char 554 // array(`valueLE`). 555 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 556 size_t lastWordPos = numFilledWords; 557 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 558 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 559 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 560 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 561 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 562 // and store it in `result`. 563 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 564 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 565 valueLE.begin(), result + lastWordPos, 566 (numBytes - lastWordPos) * CHAR_BIT, 1); 567 } 568 569 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 570 /// format. 571 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 572 APInt &result) { 573 assert(llvm::support::endian::system_endianness() == // NOLINT 574 llvm::support::endianness::big); // NOLINT 575 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 576 577 // Copy the data that fills the word of `result` from `inArray`. 578 // For example, when `result` has 2 words, the first word will be filled with 579 // data. So, the first 8 bytes are copied from `inArray` here. 580 // `inArray` (10 bytes, BE): |abcdefgh|ij| 581 // ==> `result` (2 words, BE): |abcdefgh|--------| 582 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 583 std::copy_n( 584 inArray, numFilledWords, 585 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 586 587 // Convert array data which will be last word of `result` to LE format, and 588 // store it in char array(`inArrayLE`). 589 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 590 size_t lastWordPos = numFilledWords; 591 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 592 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 593 inArray + lastWordPos, inArrayLE.begin(), 594 (numBytes - lastWordPos) * CHAR_BIT, 1); 595 596 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 597 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 598 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 599 inArrayLE.begin(), 600 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 601 lastWordPos, 602 APInt::APINT_BITS_PER_WORD, 1); 603 } 604 605 /// Writes value to the bit position `bitPos` in array `rawData`. 606 static void writeBits(char *rawData, size_t bitPos, APInt value) { 607 size_t bitWidth = value.getBitWidth(); 608 609 // If the bitwidth is 1 we just toggle the specific bit. 610 if (bitWidth == 1) 611 return setBit(rawData, bitPos, value.isOneValue()); 612 613 // Otherwise, the bit position is guaranteed to be byte aligned. 614 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 615 if (llvm::support::endian::system_endianness() == 616 llvm::support::endianness::big) { 617 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 618 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 619 // work correctly in BE format. 620 // ex. `value` (2 words including 10 bytes) 621 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 622 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 623 rawData + (bitPos / CHAR_BIT)); 624 } else { 625 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 626 llvm::divideCeil(bitWidth, CHAR_BIT), 627 rawData + (bitPos / CHAR_BIT)); 628 } 629 } 630 631 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 632 /// `rawData`. 633 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 634 // Handle a boolean bit position. 635 if (bitWidth == 1) 636 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 637 638 // Otherwise, the bit position must be 8-bit aligned. 639 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 640 APInt result(bitWidth, 0); 641 if (llvm::support::endian::system_endianness() == 642 llvm::support::endianness::big) { 643 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 644 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 645 // work correctly in BE format. 646 // ex. `result` (2 words including 10 bytes) 647 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 648 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 649 llvm::divideCeil(bitWidth, CHAR_BIT), result); 650 } else { 651 std::copy_n(rawData + (bitPos / CHAR_BIT), 652 llvm::divideCeil(bitWidth, CHAR_BIT), 653 const_cast<char *>( 654 reinterpret_cast<const char *>(result.getRawData()))); 655 } 656 return result; 657 } 658 659 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 660 /// the same element count as 'type'. 661 template <typename Values> 662 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 663 return (values.size() == 1) || 664 (type.getNumElements() == static_cast<int64_t>(values.size())); 665 } 666 667 //===----------------------------------------------------------------------===// 668 // DenseElementsAttr Iterators 669 //===----------------------------------------------------------------------===// 670 671 //===----------------------------------------------------------------------===// 672 // AttributeElementIterator 673 674 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 675 DenseElementsAttr attr, size_t index) 676 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 677 Attribute, Attribute, Attribute>( 678 attr.getAsOpaquePointer(), index) {} 679 680 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 681 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 682 Type eltTy = owner.getType().getElementType(); 683 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 684 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 685 if (eltTy.isa<IndexType>()) 686 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 687 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 688 IntElementIterator intIt(owner, index); 689 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 690 return FloatAttr::get(eltTy, *floatIt); 691 } 692 if (owner.isa<DenseStringElementsAttr>()) { 693 ArrayRef<StringRef> vals = owner.getRawStringData(); 694 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 695 } 696 llvm_unreachable("unexpected element type"); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // BoolElementIterator 701 702 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 703 DenseElementsAttr attr, size_t dataIndex) 704 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 705 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 706 707 bool DenseElementsAttr::BoolElementIterator::operator*() const { 708 return getBit(getData(), getDataIndex()); 709 } 710 711 //===----------------------------------------------------------------------===// 712 // IntElementIterator 713 714 DenseElementsAttr::IntElementIterator::IntElementIterator( 715 DenseElementsAttr attr, size_t dataIndex) 716 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 717 attr.getRawData().data(), attr.isSplat(), dataIndex), 718 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 719 720 APInt DenseElementsAttr::IntElementIterator::operator*() const { 721 return readBits(getData(), 722 getDataIndex() * getDenseElementStorageWidth(bitWidth), 723 bitWidth); 724 } 725 726 //===----------------------------------------------------------------------===// 727 // ComplexIntElementIterator 728 729 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 730 DenseElementsAttr attr, size_t dataIndex) 731 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 732 std::complex<APInt>, std::complex<APInt>, 733 std::complex<APInt>>( 734 attr.getRawData().data(), attr.isSplat(), dataIndex) { 735 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 736 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 737 } 738 739 std::complex<APInt> 740 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 741 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 742 size_t offset = getDataIndex() * storageWidth * 2; 743 return {readBits(getData(), offset, bitWidth), 744 readBits(getData(), offset + storageWidth, bitWidth)}; 745 } 746 747 //===----------------------------------------------------------------------===// 748 // FloatElementIterator 749 750 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 751 const llvm::fltSemantics &smt, IntElementIterator it) 752 : llvm::mapped_iterator<IntElementIterator, 753 std::function<APFloat(const APInt &)>>( 754 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 755 756 //===----------------------------------------------------------------------===// 757 // ComplexFloatElementIterator 758 759 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 760 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 761 : llvm::mapped_iterator< 762 ComplexIntElementIterator, 763 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 764 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 765 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 766 }) {} 767 768 //===----------------------------------------------------------------------===// 769 // DenseElementsAttr 770 //===----------------------------------------------------------------------===// 771 772 /// Method for support type inquiry through isa, cast and dyn_cast. 773 bool DenseElementsAttr::classof(Attribute attr) { 774 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 775 } 776 777 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 778 ArrayRef<Attribute> values) { 779 assert(hasSameElementsOrSplat(type, values)); 780 781 // If the element type is not based on int/float/index, assume it is a string 782 // type. 783 auto eltType = type.getElementType(); 784 if (!type.getElementType().isIntOrIndexOrFloat()) { 785 SmallVector<StringRef, 8> stringValues; 786 stringValues.reserve(values.size()); 787 for (Attribute attr : values) { 788 assert(attr.isa<StringAttr>() && 789 "expected string value for non integer/index/float element"); 790 stringValues.push_back(attr.cast<StringAttr>().getValue()); 791 } 792 return get(type, stringValues); 793 } 794 795 // Otherwise, get the raw storage width to use for the allocation. 796 size_t bitWidth = getDenseElementBitWidth(eltType); 797 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 798 799 // Compress the attribute values into a character buffer. 800 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 801 values.size()); 802 APInt intVal; 803 for (unsigned i = 0, e = values.size(); i < e; ++i) { 804 assert(eltType == values[i].getType() && 805 "expected attribute value to have element type"); 806 if (eltType.isa<FloatType>()) 807 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 808 else if (eltType.isa<IntegerType>()) 809 intVal = values[i].cast<IntegerAttr>().getValue(); 810 else 811 llvm_unreachable("unexpected element type"); 812 813 assert(intVal.getBitWidth() == bitWidth && 814 "expected value to have same bitwidth as element type"); 815 writeBits(data.data(), i * storageBitWidth, intVal); 816 } 817 return DenseIntOrFPElementsAttr::getRaw(type, data, 818 /*isSplat=*/(values.size() == 1)); 819 } 820 821 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 822 ArrayRef<bool> values) { 823 assert(hasSameElementsOrSplat(type, values)); 824 assert(type.getElementType().isInteger(1)); 825 826 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 827 for (int i = 0, e = values.size(); i != e; ++i) 828 setBit(buff.data(), i, values[i]); 829 return DenseIntOrFPElementsAttr::getRaw(type, buff, 830 /*isSplat=*/(values.size() == 1)); 831 } 832 833 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 834 ArrayRef<StringRef> values) { 835 assert(!type.getElementType().isIntOrFloat()); 836 return DenseStringElementsAttr::get(type, values); 837 } 838 839 /// Constructs a dense integer elements attribute from an array of APInt 840 /// values. Each APInt value is expected to have the same bitwidth as the 841 /// element type of 'type'. 842 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 843 ArrayRef<APInt> values) { 844 assert(type.getElementType().isIntOrIndex()); 845 assert(hasSameElementsOrSplat(type, values)); 846 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 847 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 848 /*isSplat=*/(values.size() == 1)); 849 } 850 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 851 ArrayRef<std::complex<APInt>> values) { 852 ComplexType complex = type.getElementType().cast<ComplexType>(); 853 assert(complex.getElementType().isa<IntegerType>()); 854 assert(hasSameElementsOrSplat(type, values)); 855 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 856 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 857 values.size() * 2); 858 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 859 /*isSplat=*/(values.size() == 1)); 860 } 861 862 // Constructs a dense float elements attribute from an array of APFloat 863 // values. Each APFloat value is expected to have the same bitwidth as the 864 // element type of 'type'. 865 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 866 ArrayRef<APFloat> values) { 867 assert(type.getElementType().isa<FloatType>()); 868 assert(hasSameElementsOrSplat(type, values)); 869 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 870 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 871 /*isSplat=*/(values.size() == 1)); 872 } 873 DenseElementsAttr 874 DenseElementsAttr::get(ShapedType type, 875 ArrayRef<std::complex<APFloat>> values) { 876 ComplexType complex = type.getElementType().cast<ComplexType>(); 877 assert(complex.getElementType().isa<FloatType>()); 878 assert(hasSameElementsOrSplat(type, values)); 879 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 880 values.size() * 2); 881 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 882 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 883 /*isSplat=*/(values.size() == 1)); 884 } 885 886 /// Construct a dense elements attribute from a raw buffer representing the 887 /// data for this attribute. Users should generally not use this methods as 888 /// the expected buffer format may not be a form the user expects. 889 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 890 ArrayRef<char> rawBuffer, 891 bool isSplatBuffer) { 892 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 893 } 894 895 /// Returns true if the given buffer is a valid raw buffer for the given type. 896 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 897 ArrayRef<char> rawBuffer, 898 bool &detectedSplat) { 899 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 900 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 901 902 // Storage width of 1 is special as it is packed by the bit. 903 if (storageWidth == 1) { 904 // Check for a splat, or a buffer equal to the number of elements. 905 if ((detectedSplat = rawBuffer.size() == 1)) 906 return true; 907 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 908 } 909 // All other types are 8-bit aligned. 910 if ((detectedSplat = rawBufferWidth == storageWidth)) 911 return true; 912 return rawBufferWidth == (storageWidth * type.getNumElements()); 913 } 914 915 /// Check the information for a C++ data type, check if this type is valid for 916 /// the current attribute. This method is used to verify specific type 917 /// invariants that the templatized 'getValues' method cannot. 918 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 919 bool isSigned) { 920 // Make sure that the data element size is the same as the type element width. 921 if (getDenseElementBitWidth(type) != 922 static_cast<size_t>(dataEltSize * CHAR_BIT)) 923 return false; 924 925 // Check that the element type is either float or integer or index. 926 if (!isInt) 927 return type.isa<FloatType>(); 928 if (type.isIndex()) 929 return true; 930 931 auto intType = type.dyn_cast<IntegerType>(); 932 if (!intType) 933 return false; 934 935 // Make sure signedness semantics is consistent. 936 if (intType.isSignless()) 937 return true; 938 return intType.isSigned() ? isSigned : !isSigned; 939 } 940 941 /// Defaults down the subclass implementation. 942 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 943 ArrayRef<char> data, 944 int64_t dataEltSize, 945 bool isInt, bool isSigned) { 946 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 947 isSigned); 948 } 949 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 950 ArrayRef<char> data, 951 int64_t dataEltSize, 952 bool isInt, 953 bool isSigned) { 954 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 955 isInt, isSigned); 956 } 957 958 /// A method used to verify specific type invariants that the templatized 'get' 959 /// method cannot. 960 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 961 bool isSigned) const { 962 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 963 isSigned); 964 } 965 966 /// Check the information for a C++ data type, check if this type is valid for 967 /// the current attribute. 968 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 969 bool isSigned) const { 970 return ::isValidIntOrFloat( 971 getType().getElementType().cast<ComplexType>().getElementType(), 972 dataEltSize / 2, isInt, isSigned); 973 } 974 975 /// Returns true if this attribute corresponds to a splat, i.e. if all element 976 /// values are the same. 977 bool DenseElementsAttr::isSplat() const { 978 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 979 } 980 981 /// Return the held element values as a range of Attributes. 982 auto DenseElementsAttr::getAttributeValues() const 983 -> llvm::iterator_range<AttributeElementIterator> { 984 return {attr_value_begin(), attr_value_end()}; 985 } 986 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 987 return AttributeElementIterator(*this, 0); 988 } 989 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 990 return AttributeElementIterator(*this, getNumElements()); 991 } 992 993 /// Return the held element values as a range of bool. The element type of 994 /// this attribute must be of integer type of bitwidth 1. 995 auto DenseElementsAttr::getBoolValues() const 996 -> llvm::iterator_range<BoolElementIterator> { 997 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 998 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 999 (void)eltType; 1000 return {BoolElementIterator(*this, 0), 1001 BoolElementIterator(*this, getNumElements())}; 1002 } 1003 1004 /// Return the held element values as a range of APInts. The element type of 1005 /// this attribute must be of integer type. 1006 auto DenseElementsAttr::getIntValues() const 1007 -> llvm::iterator_range<IntElementIterator> { 1008 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1009 return {raw_int_begin(), raw_int_end()}; 1010 } 1011 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 1012 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1013 return raw_int_begin(); 1014 } 1015 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 1016 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1017 return raw_int_end(); 1018 } 1019 auto DenseElementsAttr::getComplexIntValues() const 1020 -> llvm::iterator_range<ComplexIntElementIterator> { 1021 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1022 (void)eltTy; 1023 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 1024 return {ComplexIntElementIterator(*this, 0), 1025 ComplexIntElementIterator(*this, getNumElements())}; 1026 } 1027 1028 /// Return the held element values as a range of APFloat. The element type of 1029 /// this attribute must be of float type. 1030 auto DenseElementsAttr::getFloatValues() const 1031 -> llvm::iterator_range<FloatElementIterator> { 1032 auto elementType = getType().getElementType().cast<FloatType>(); 1033 const auto &elementSemantics = elementType.getFloatSemantics(); 1034 return {FloatElementIterator(elementSemantics, raw_int_begin()), 1035 FloatElementIterator(elementSemantics, raw_int_end())}; 1036 } 1037 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1038 return getFloatValues().begin(); 1039 } 1040 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1041 return getFloatValues().end(); 1042 } 1043 auto DenseElementsAttr::getComplexFloatValues() const 1044 -> llvm::iterator_range<ComplexFloatElementIterator> { 1045 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1046 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1047 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1048 return {{semantics, {*this, 0}}, 1049 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1050 } 1051 1052 /// Return the raw storage data held by this attribute. 1053 ArrayRef<char> DenseElementsAttr::getRawData() const { 1054 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1055 } 1056 1057 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1058 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1059 } 1060 1061 /// Return a new DenseElementsAttr that has the same data as the current 1062 /// attribute, but has been reshaped to 'newType'. The new type must have the 1063 /// same total number of elements as well as element type. 1064 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1065 ShapedType curType = getType(); 1066 if (curType == newType) 1067 return *this; 1068 1069 (void)curType; 1070 assert(newType.getElementType() == curType.getElementType() && 1071 "expected the same element type"); 1072 assert(newType.getNumElements() == curType.getNumElements() && 1073 "expected the same number of elements"); 1074 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1075 } 1076 1077 DenseElementsAttr 1078 DenseElementsAttr::mapValues(Type newElementType, 1079 function_ref<APInt(const APInt &)> mapping) const { 1080 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1081 } 1082 1083 DenseElementsAttr DenseElementsAttr::mapValues( 1084 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1085 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1086 } 1087 1088 //===----------------------------------------------------------------------===// 1089 // DenseStringElementsAttr 1090 //===----------------------------------------------------------------------===// 1091 1092 DenseStringElementsAttr 1093 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1094 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1095 } 1096 1097 //===----------------------------------------------------------------------===// 1098 // DenseIntOrFPElementsAttr 1099 //===----------------------------------------------------------------------===// 1100 1101 /// Utility method to write a range of APInt values to a buffer. 1102 template <typename APRangeT> 1103 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1104 APRangeT &&values) { 1105 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1106 size_t offset = 0; 1107 for (auto it = values.begin(), e = values.end(); it != e; 1108 ++it, offset += storageWidth) { 1109 assert((*it).getBitWidth() <= storageWidth); 1110 writeBits(data.data(), offset, *it); 1111 } 1112 } 1113 1114 /// Constructs a dense elements attribute from an array of raw APFloat values. 1115 /// Each APFloat value is expected to have the same bitwidth as the element 1116 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1117 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1118 size_t storageWidth, 1119 ArrayRef<APFloat> values, 1120 bool isSplat) { 1121 std::vector<char> data; 1122 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1123 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1124 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1125 } 1126 1127 /// Constructs a dense elements attribute from an array of raw APInt values. 1128 /// Each APInt value is expected to have the same bitwidth as the element type 1129 /// of 'type'. 1130 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1131 size_t storageWidth, 1132 ArrayRef<APInt> values, 1133 bool isSplat) { 1134 std::vector<char> data; 1135 writeAPIntsToBuffer(storageWidth, data, values); 1136 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1137 } 1138 1139 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1140 ArrayRef<char> data, 1141 bool isSplat) { 1142 assert((type.isa<RankedTensorType, VectorType>()) && 1143 "type must be ranked tensor or vector"); 1144 assert(type.hasStaticShape() && "type must have static shape"); 1145 return Base::get(type.getContext(), type, data, isSplat); 1146 } 1147 1148 /// Overload of the raw 'get' method that asserts that the given type is of 1149 /// complex type. This method is used to verify type invariants that the 1150 /// templatized 'get' method cannot. 1151 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1152 ArrayRef<char> data, 1153 int64_t dataEltSize, 1154 bool isInt, 1155 bool isSigned) { 1156 assert(::isValidIntOrFloat( 1157 type.getElementType().cast<ComplexType>().getElementType(), 1158 dataEltSize / 2, isInt, isSigned)); 1159 1160 int64_t numElements = data.size() / dataEltSize; 1161 assert(numElements == 1 || numElements == type.getNumElements()); 1162 return getRaw(type, data, /*isSplat=*/numElements == 1); 1163 } 1164 1165 /// Overload of the 'getRaw' method that asserts that the given type is of 1166 /// integer type. This method is used to verify type invariants that the 1167 /// templatized 'get' method cannot. 1168 DenseElementsAttr 1169 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1170 int64_t dataEltSize, bool isInt, 1171 bool isSigned) { 1172 assert( 1173 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1174 1175 int64_t numElements = data.size() / dataEltSize; 1176 assert(numElements == 1 || numElements == type.getNumElements()); 1177 return getRaw(type, data, /*isSplat=*/numElements == 1); 1178 } 1179 1180 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1181 const char *inRawData, char *outRawData, size_t elementBitWidth, 1182 size_t numElements) { 1183 using llvm::support::ulittle16_t; 1184 using llvm::support::ulittle32_t; 1185 using llvm::support::ulittle64_t; 1186 1187 assert(llvm::support::endian::system_endianness() == // NOLINT 1188 llvm::support::endianness::big); // NOLINT 1189 // NOLINT to avoid warning message about replacing by static_assert() 1190 1191 // Following std::copy_n always converts endianness on BE machine. 1192 switch (elementBitWidth) { 1193 case 16: { 1194 const ulittle16_t *inRawDataPos = 1195 reinterpret_cast<const ulittle16_t *>(inRawData); 1196 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1197 std::copy_n(inRawDataPos, numElements, outDataPos); 1198 break; 1199 } 1200 case 32: { 1201 const ulittle32_t *inRawDataPos = 1202 reinterpret_cast<const ulittle32_t *>(inRawData); 1203 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1204 std::copy_n(inRawDataPos, numElements, outDataPos); 1205 break; 1206 } 1207 case 64: { 1208 const ulittle64_t *inRawDataPos = 1209 reinterpret_cast<const ulittle64_t *>(inRawData); 1210 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1211 std::copy_n(inRawDataPos, numElements, outDataPos); 1212 break; 1213 } 1214 default: { 1215 size_t nBytes = elementBitWidth / CHAR_BIT; 1216 for (size_t i = 0; i < nBytes; i++) 1217 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1218 break; 1219 } 1220 } 1221 } 1222 1223 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1224 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1225 ShapedType type) { 1226 size_t numElements = type.getNumElements(); 1227 Type elementType = type.getElementType(); 1228 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1229 elementType = complexTy.getElementType(); 1230 numElements = numElements * 2; 1231 } 1232 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1233 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1234 inRawData.size() <= outRawData.size()); 1235 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1236 elementBitWidth, numElements); 1237 } 1238 1239 //===----------------------------------------------------------------------===// 1240 // DenseFPElementsAttr 1241 //===----------------------------------------------------------------------===// 1242 1243 template <typename Fn, typename Attr> 1244 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1245 Type newElementType, 1246 llvm::SmallVectorImpl<char> &data) { 1247 size_t bitWidth = getDenseElementBitWidth(newElementType); 1248 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1249 1250 ShapedType newArrayType; 1251 if (inType.isa<RankedTensorType>()) 1252 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1253 else if (inType.isa<UnrankedTensorType>()) 1254 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1255 else if (inType.isa<VectorType>()) 1256 newArrayType = VectorType::get(inType.getShape(), newElementType); 1257 else 1258 assert(newArrayType && "Unhandled tensor type"); 1259 1260 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1261 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1262 1263 // Functor used to process a single element value of the attribute. 1264 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1265 auto newInt = mapping(value); 1266 assert(newInt.getBitWidth() == bitWidth); 1267 writeBits(data.data(), index * storageBitWidth, newInt); 1268 }; 1269 1270 // Check for the splat case. 1271 if (attr.isSplat()) { 1272 processElt(*attr.begin(), /*index=*/0); 1273 return newArrayType; 1274 } 1275 1276 // Otherwise, process all of the element values. 1277 uint64_t elementIdx = 0; 1278 for (auto value : attr) 1279 processElt(value, elementIdx++); 1280 return newArrayType; 1281 } 1282 1283 DenseElementsAttr DenseFPElementsAttr::mapValues( 1284 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1285 llvm::SmallVector<char, 8> elementData; 1286 auto newArrayType = 1287 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1288 1289 return getRaw(newArrayType, elementData, isSplat()); 1290 } 1291 1292 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1293 bool DenseFPElementsAttr::classof(Attribute attr) { 1294 return attr.isa<DenseElementsAttr>() && 1295 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1296 } 1297 1298 //===----------------------------------------------------------------------===// 1299 // DenseIntElementsAttr 1300 //===----------------------------------------------------------------------===// 1301 1302 DenseElementsAttr DenseIntElementsAttr::mapValues( 1303 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1304 llvm::SmallVector<char, 8> elementData; 1305 auto newArrayType = 1306 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1307 1308 return getRaw(newArrayType, elementData, isSplat()); 1309 } 1310 1311 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1312 bool DenseIntElementsAttr::classof(Attribute attr) { 1313 return attr.isa<DenseElementsAttr>() && 1314 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1315 } 1316 1317 //===----------------------------------------------------------------------===// 1318 // OpaqueElementsAttr 1319 //===----------------------------------------------------------------------===// 1320 1321 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1322 StringRef bytes) { 1323 assert(TensorType::isValidElementType(type.getElementType()) && 1324 "Input element type should be a valid tensor element type"); 1325 return Base::get(type.getContext(), type, dialect, bytes); 1326 } 1327 1328 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1329 1330 /// Return the value at the given index. If index does not refer to a valid 1331 /// element, then a null attribute is returned. 1332 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1333 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1334 return Attribute(); 1335 } 1336 1337 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1338 1339 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1340 auto *d = getDialect(); 1341 if (!d) 1342 return true; 1343 auto *interface = 1344 d->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1345 if (!interface) 1346 return true; 1347 return failed(interface->decode(*this, result)); 1348 } 1349 1350 //===----------------------------------------------------------------------===// 1351 // SparseElementsAttr 1352 //===----------------------------------------------------------------------===// 1353 1354 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1355 DenseElementsAttr indices, 1356 DenseElementsAttr values) { 1357 assert(indices.getType().getElementType().isInteger(64) && 1358 "expected sparse indices to be 64-bit integer values"); 1359 assert((type.isa<RankedTensorType, VectorType>()) && 1360 "type must be ranked tensor or vector"); 1361 assert(type.hasStaticShape() && "type must have static shape"); 1362 return Base::get(type.getContext(), type, 1363 indices.cast<DenseIntElementsAttr>(), values); 1364 } 1365 1366 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1367 return getImpl()->indices; 1368 } 1369 1370 DenseElementsAttr SparseElementsAttr::getValues() const { 1371 return getImpl()->values; 1372 } 1373 1374 /// Return the value of the element at the given index. 1375 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1376 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1377 auto type = getType(); 1378 1379 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1380 // as a 1-D index array. 1381 auto sparseIndices = getIndices(); 1382 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1383 1384 // Check to see if the indices are a splat. 1385 if (sparseIndices.isSplat()) { 1386 // If the index is also not a splat of the index value, we know that the 1387 // value is zero. 1388 auto splatIndex = *sparseIndexValues.begin(); 1389 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1390 return getZeroAttr(); 1391 1392 // If the indices are a splat, we also expect the values to be a splat. 1393 assert(getValues().isSplat() && "expected splat values"); 1394 return getValues().getSplatValue(); 1395 } 1396 1397 // Build a mapping between known indices and the offset of the stored element. 1398 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1399 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1400 size_t rank = type.getRank(); 1401 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1402 mappedIndices.try_emplace( 1403 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1404 1405 // Look for the provided index key within the mapped indices. If the provided 1406 // index is not found, then return a zero attribute. 1407 auto it = mappedIndices.find(index); 1408 if (it == mappedIndices.end()) 1409 return getZeroAttr(); 1410 1411 // Otherwise, return the held sparse value element. 1412 return getValues().getValue(it->second); 1413 } 1414 1415 /// Get a zero APFloat for the given sparse attribute. 1416 APFloat SparseElementsAttr::getZeroAPFloat() const { 1417 auto eltType = getType().getElementType().cast<FloatType>(); 1418 return APFloat(eltType.getFloatSemantics()); 1419 } 1420 1421 /// Get a zero APInt for the given sparse attribute. 1422 APInt SparseElementsAttr::getZeroAPInt() const { 1423 auto eltType = getType().getElementType().cast<IntegerType>(); 1424 return APInt::getNullValue(eltType.getWidth()); 1425 } 1426 1427 /// Get a zero attribute for the given attribute type. 1428 Attribute SparseElementsAttr::getZeroAttr() const { 1429 auto eltType = getType().getElementType(); 1430 1431 // Handle floating point elements. 1432 if (eltType.isa<FloatType>()) 1433 return FloatAttr::get(eltType, 0); 1434 1435 // Otherwise, this is an integer. 1436 // TODO: Handle StringAttr here. 1437 return IntegerAttr::get(eltType, 0); 1438 } 1439 1440 /// Flatten, and return, all of the sparse indices in this attribute in 1441 /// row-major order. 1442 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1443 std::vector<ptrdiff_t> flatSparseIndices; 1444 1445 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1446 // as a 1-D index array. 1447 auto sparseIndices = getIndices(); 1448 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1449 if (sparseIndices.isSplat()) { 1450 SmallVector<uint64_t, 8> indices(getType().getRank(), 1451 *sparseIndexValues.begin()); 1452 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1453 return flatSparseIndices; 1454 } 1455 1456 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1457 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1458 size_t rank = getType().getRank(); 1459 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1460 flatSparseIndices.push_back(getFlattenedIndex( 1461 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1462 return flatSparseIndices; 1463 } 1464