1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AffineMapAttr 26 //===----------------------------------------------------------------------===// 27 28 AffineMapAttr AffineMapAttr::get(AffineMap value) { 29 return Base::get(value.getContext(), value); 30 } 31 32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 33 34 //===----------------------------------------------------------------------===// 35 // ArrayAttr 36 //===----------------------------------------------------------------------===// 37 38 ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) { 39 return Base::get(context, value); 40 } 41 42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 43 44 Attribute ArrayAttr::operator[](unsigned idx) const { 45 assert(idx < size() && "index out of bounds"); 46 return getValue()[idx]; 47 } 48 49 //===----------------------------------------------------------------------===// 50 // DictionaryAttr 51 //===----------------------------------------------------------------------===// 52 53 /// Helper function that does either an in place sort or sorts from source array 54 /// into destination. If inPlace then storage is both the source and the 55 /// destination, else value is the source and storage destination. Returns 56 /// whether source was sorted. 57 template <bool inPlace> 58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 59 SmallVectorImpl<NamedAttribute> &storage) { 60 // Specialize for the common case. 61 switch (value.size()) { 62 case 0: 63 // Zero already sorted. 64 break; 65 case 1: 66 // One already sorted but may need to be copied. 67 if (!inPlace) 68 storage.assign({value[0]}); 69 break; 70 case 2: { 71 bool isSorted = value[0] < value[1]; 72 if (inPlace) { 73 if (!isSorted) 74 std::swap(storage[0], storage[1]); 75 } else if (isSorted) { 76 storage.assign({value[0], value[1]}); 77 } else { 78 storage.assign({value[1], value[0]}); 79 } 80 return !isSorted; 81 } 82 default: 83 if (!inPlace) 84 storage.assign(value.begin(), value.end()); 85 // Check to see they are sorted already. 86 bool isSorted = llvm::is_sorted(value); 87 if (!isSorted) { 88 // If not, do a general sort. 89 llvm::array_pod_sort(storage.begin(), storage.end()); 90 value = storage; 91 } 92 return !isSorted; 93 } 94 return false; 95 } 96 97 /// Returns an entry with a duplicate name from the given sorted array of named 98 /// attributes. Returns llvm::None if all elements have unique names. 99 static Optional<NamedAttribute> 100 findDuplicateElement(ArrayRef<NamedAttribute> value) { 101 const Optional<NamedAttribute> none{llvm::None}; 102 if (value.size() < 2) 103 return none; 104 105 if (value.size() == 2) 106 return value[0].first == value[1].first ? value[0] : none; 107 108 auto it = std::adjacent_find( 109 value.begin(), value.end(), 110 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 111 return it != value.end() ? *it : none; 112 } 113 114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 115 SmallVectorImpl<NamedAttribute> &storage) { 116 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 117 assert(!findDuplicateElement(storage) && 118 "DictionaryAttr element names must be unique"); 119 return isSorted; 120 } 121 122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 123 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 124 assert(!findDuplicateElement(array) && 125 "DictionaryAttr element names must be unique"); 126 return isSorted; 127 } 128 129 Optional<NamedAttribute> 130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 131 bool isSorted) { 132 if (!isSorted) 133 dictionaryAttrSort</*inPlace=*/true>(array, array); 134 return findDuplicateElement(array); 135 } 136 137 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 138 ArrayRef<NamedAttribute> value) { 139 if (value.empty()) 140 return DictionaryAttr::getEmpty(context); 141 assert(llvm::all_of(value, 142 [](const NamedAttribute &attr) { return attr.second; }) && 143 "value cannot have null entries"); 144 145 // We need to sort the element list to canonicalize it. 146 SmallVector<NamedAttribute, 8> storage; 147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 148 value = storage; 149 assert(!findDuplicateElement(value) && 150 "DictionaryAttr element names must be unique"); 151 return Base::get(context, value); 152 } 153 /// Construct a dictionary with an array of values that is known to already be 154 /// sorted by name and uniqued. 155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 156 MLIRContext *context) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 // Ensure that the attribute elements are unique and sorted. 160 assert(llvm::is_sorted(value, 161 [](NamedAttribute l, NamedAttribute r) { 162 return l.first.strref() < r.first.strref(); 163 }) && 164 "expected attribute values to be sorted"); 165 assert(!findDuplicateElement(value) && 166 "DictionaryAttr element names must be unique"); 167 return Base::get(context, value); 168 } 169 170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 171 return getImpl()->getElements(); 172 } 173 174 /// Return the specified attribute if present, null otherwise. 175 Attribute DictionaryAttr::get(StringRef name) const { 176 Optional<NamedAttribute> attr = getNamed(name); 177 return attr ? attr->second : nullptr; 178 } 179 Attribute DictionaryAttr::get(Identifier name) const { 180 Optional<NamedAttribute> attr = getNamed(name); 181 return attr ? attr->second : nullptr; 182 } 183 184 /// Return the specified named attribute if present, None otherwise. 185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 186 ArrayRef<NamedAttribute> values = getValue(); 187 const auto *it = llvm::lower_bound(values, name); 188 return it != values.end() && it->first == name ? *it 189 : Optional<NamedAttribute>(); 190 } 191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 192 for (auto elt : getValue()) 193 if (elt.first == name) 194 return elt; 195 return llvm::None; 196 } 197 198 DictionaryAttr::iterator DictionaryAttr::begin() const { 199 return getValue().begin(); 200 } 201 DictionaryAttr::iterator DictionaryAttr::end() const { 202 return getValue().end(); 203 } 204 size_t DictionaryAttr::size() const { return getValue().size(); } 205 206 //===----------------------------------------------------------------------===// 207 // FloatAttr 208 //===----------------------------------------------------------------------===// 209 210 FloatAttr FloatAttr::get(Type type, double value) { 211 return Base::get(type.getContext(), type, value); 212 } 213 214 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 215 return Base::getChecked(loc, type, value); 216 } 217 218 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 219 return Base::get(type.getContext(), type, value); 220 } 221 222 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 223 return Base::getChecked(loc, type, value); 224 } 225 226 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 227 228 double FloatAttr::getValueAsDouble() const { 229 return getValueAsDouble(getValue()); 230 } 231 double FloatAttr::getValueAsDouble(APFloat value) { 232 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 233 bool losesInfo = false; 234 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 235 &losesInfo); 236 } 237 return value.convertToDouble(); 238 } 239 240 /// Verify construction invariants. 241 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 242 if (!type.isa<FloatType>()) 243 return emitError(loc, "expected floating point type"); 244 return success(); 245 } 246 247 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 248 double value) { 249 return verifyFloatTypeInvariants(loc, type); 250 } 251 252 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 253 const APFloat &value) { 254 // Verify that the type is correct. 255 if (failed(verifyFloatTypeInvariants(loc, type))) 256 return failure(); 257 258 // Verify that the type semantics match that of the value. 259 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 260 return emitError( 261 loc, "FloatAttr type doesn't match the type implied by its value"); 262 } 263 return success(); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // SymbolRefAttr 268 //===----------------------------------------------------------------------===// 269 270 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 271 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 272 } 273 274 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 275 ArrayRef<FlatSymbolRefAttr> nestedReferences) { 276 return Base::get(ctx, value, nestedReferences); 277 } 278 279 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 280 281 StringRef SymbolRefAttr::getLeafReference() const { 282 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 283 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 284 } 285 286 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 287 return getImpl()->getNestedRefs(); 288 } 289 290 //===----------------------------------------------------------------------===// 291 // IntegerAttr 292 //===----------------------------------------------------------------------===// 293 294 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 295 if (type.isSignlessInteger(1)) 296 return BoolAttr::get(type.getContext(), value.getBoolValue()); 297 return Base::get(type.getContext(), type, value); 298 } 299 300 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 301 // This uses 64 bit APInts by default for index type. 302 if (type.isIndex()) 303 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 304 305 auto intType = type.cast<IntegerType>(); 306 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 307 } 308 309 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 310 311 int64_t IntegerAttr::getInt() const { 312 assert((getImpl()->getType().isIndex() || 313 getImpl()->getType().isSignlessInteger()) && 314 "must be signless integer"); 315 return getValue().getSExtValue(); 316 } 317 318 int64_t IntegerAttr::getSInt() const { 319 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 320 return getValue().getSExtValue(); 321 } 322 323 uint64_t IntegerAttr::getUInt() const { 324 assert(getImpl()->getType().isUnsignedInteger() && 325 "must be unsigned integer"); 326 return getValue().getZExtValue(); 327 } 328 329 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 330 if (type.isa<IntegerType, IndexType>()) 331 return success(); 332 return emitError(loc, "expected integer or index type"); 333 } 334 335 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 336 int64_t value) { 337 return verifyIntegerTypeInvariants(loc, type); 338 } 339 340 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 341 const APInt &value) { 342 if (failed(verifyIntegerTypeInvariants(loc, type))) 343 return failure(); 344 if (auto integerType = type.dyn_cast<IntegerType>()) 345 if (integerType.getWidth() != value.getBitWidth()) 346 return emitError(loc, "integer type bit width (") 347 << integerType.getWidth() << ") doesn't match value bit width (" 348 << value.getBitWidth() << ")"; 349 return success(); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // BoolAttr 354 355 bool BoolAttr::getValue() const { 356 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 357 return storage->getValue().getBoolValue(); 358 } 359 360 bool BoolAttr::classof(Attribute attr) { 361 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 362 return intAttr && intAttr.getType().isSignlessInteger(1); 363 } 364 365 //===----------------------------------------------------------------------===// 366 // IntegerSetAttr 367 //===----------------------------------------------------------------------===// 368 369 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 370 return Base::get(value.getConstraint(0).getContext(), value); 371 } 372 373 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 374 375 //===----------------------------------------------------------------------===// 376 // OpaqueAttr 377 //===----------------------------------------------------------------------===// 378 379 OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect, 380 StringRef attrData, Type type) { 381 return Base::get(context, dialect, attrData, type); 382 } 383 384 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 385 Type type, Location location) { 386 return Base::getChecked(location, dialect, attrData, type); 387 } 388 389 /// Returns the dialect namespace of the opaque attribute. 390 Identifier OpaqueAttr::getDialectNamespace() const { 391 return getImpl()->dialectNamespace; 392 } 393 394 /// Returns the raw attribute data of the opaque attribute. 395 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 396 397 /// Verify the construction of an opaque attribute. 398 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 399 Identifier dialect, 400 StringRef attrData, 401 Type type) { 402 if (!Dialect::isValidNamespace(dialect.strref())) 403 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 404 return success(); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // StringAttr 409 //===----------------------------------------------------------------------===// 410 411 StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) { 412 return get(bytes, NoneType::get(context)); 413 } 414 415 /// Get an instance of a StringAttr with the given string and Type. 416 StringAttr StringAttr::get(StringRef bytes, Type type) { 417 return Base::get(type.getContext(), bytes, type); 418 } 419 420 StringRef StringAttr::getValue() const { return getImpl()->value; } 421 422 //===----------------------------------------------------------------------===// 423 // TypeAttr 424 //===----------------------------------------------------------------------===// 425 426 TypeAttr TypeAttr::get(Type value) { 427 return Base::get(value.getContext(), value); 428 } 429 430 Type TypeAttr::getValue() const { return getImpl()->value; } 431 432 //===----------------------------------------------------------------------===// 433 // ElementsAttr 434 //===----------------------------------------------------------------------===// 435 436 ShapedType ElementsAttr::getType() const { 437 return Attribute::getType().cast<ShapedType>(); 438 } 439 440 /// Returns the number of elements held by this attribute. 441 int64_t ElementsAttr::getNumElements() const { 442 return getType().getNumElements(); 443 } 444 445 /// Return the value at the given index. If index does not refer to a valid 446 /// element, then a null attribute is returned. 447 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 448 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 449 return denseAttr.getValue(index); 450 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 451 return opaqueAttr.getValue(index); 452 return cast<SparseElementsAttr>().getValue(index); 453 } 454 455 /// Return if the given 'index' refers to a valid element in this attribute. 456 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 457 auto type = getType(); 458 459 // Verify that the rank of the indices matches the held type. 460 auto rank = type.getRank(); 461 if (rank == 0 && index.size() == 1 && index[0] == 0) 462 return true; 463 if (rank != static_cast<int64_t>(index.size())) 464 return false; 465 466 // Verify that all of the indices are within the shape dimensions. 467 auto shape = type.getShape(); 468 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 469 return static_cast<int64_t>(index[i]) < shape[i]; 470 }); 471 } 472 473 ElementsAttr 474 ElementsAttr::mapValues(Type newElementType, 475 function_ref<APInt(const APInt &)> mapping) const { 476 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 477 return intOrFpAttr.mapValues(newElementType, mapping); 478 llvm_unreachable("unsupported ElementsAttr subtype"); 479 } 480 481 ElementsAttr 482 ElementsAttr::mapValues(Type newElementType, 483 function_ref<APInt(const APFloat &)> mapping) const { 484 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 485 return intOrFpAttr.mapValues(newElementType, mapping); 486 llvm_unreachable("unsupported ElementsAttr subtype"); 487 } 488 489 /// Method for support type inquiry through isa, cast and dyn_cast. 490 bool ElementsAttr::classof(Attribute attr) { 491 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 492 OpaqueElementsAttr, SparseElementsAttr>(); 493 } 494 495 /// Returns the 1 dimensional flattened row-major index from the given 496 /// multi-dimensional index. 497 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 498 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 499 auto type = getType(); 500 501 // Reduce the provided multidimensional index into a flattended 1D row-major 502 // index. 503 auto rank = type.getRank(); 504 auto shape = type.getShape(); 505 uint64_t valueIndex = 0; 506 uint64_t dimMultiplier = 1; 507 for (int i = rank - 1; i >= 0; --i) { 508 valueIndex += index[i] * dimMultiplier; 509 dimMultiplier *= shape[i]; 510 } 511 return valueIndex; 512 } 513 514 //===----------------------------------------------------------------------===// 515 // DenseElementsAttr Utilities 516 //===----------------------------------------------------------------------===// 517 518 /// Get the bitwidth of a dense element type within the buffer. 519 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 520 static size_t getDenseElementStorageWidth(size_t origWidth) { 521 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 522 } 523 static size_t getDenseElementStorageWidth(Type elementType) { 524 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 525 } 526 527 /// Set a bit to a specific value. 528 static void setBit(char *rawData, size_t bitPos, bool value) { 529 if (value) 530 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 531 else 532 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 533 } 534 535 /// Return the value of the specified bit. 536 static bool getBit(const char *rawData, size_t bitPos) { 537 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 538 } 539 540 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 541 /// BE format. 542 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 543 char *result) { 544 assert(llvm::support::endian::system_endianness() == // NOLINT 545 llvm::support::endianness::big); // NOLINT 546 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 547 548 // Copy the words filled with data. 549 // For example, when `value` has 2 words, the first word is filled with data. 550 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 551 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 552 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 553 numFilledWords, result); 554 // Convert last word of APInt to LE format and store it in char 555 // array(`valueLE`). 556 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 557 size_t lastWordPos = numFilledWords; 558 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 559 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 560 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 561 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 562 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 563 // and store it in `result`. 564 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 565 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 566 valueLE.begin(), result + lastWordPos, 567 (numBytes - lastWordPos) * CHAR_BIT, 1); 568 } 569 570 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 571 /// format. 572 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 573 APInt &result) { 574 assert(llvm::support::endian::system_endianness() == // NOLINT 575 llvm::support::endianness::big); // NOLINT 576 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 577 578 // Copy the data that fills the word of `result` from `inArray`. 579 // For example, when `result` has 2 words, the first word will be filled with 580 // data. So, the first 8 bytes are copied from `inArray` here. 581 // `inArray` (10 bytes, BE): |abcdefgh|ij| 582 // ==> `result` (2 words, BE): |abcdefgh|--------| 583 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 584 std::copy_n( 585 inArray, numFilledWords, 586 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 587 588 // Convert array data which will be last word of `result` to LE format, and 589 // store it in char array(`inArrayLE`). 590 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 591 size_t lastWordPos = numFilledWords; 592 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 593 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 594 inArray + lastWordPos, inArrayLE.begin(), 595 (numBytes - lastWordPos) * CHAR_BIT, 1); 596 597 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 598 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 599 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 600 inArrayLE.begin(), 601 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 602 lastWordPos, 603 APInt::APINT_BITS_PER_WORD, 1); 604 } 605 606 /// Writes value to the bit position `bitPos` in array `rawData`. 607 static void writeBits(char *rawData, size_t bitPos, APInt value) { 608 size_t bitWidth = value.getBitWidth(); 609 610 // If the bitwidth is 1 we just toggle the specific bit. 611 if (bitWidth == 1) 612 return setBit(rawData, bitPos, value.isOneValue()); 613 614 // Otherwise, the bit position is guaranteed to be byte aligned. 615 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 616 if (llvm::support::endian::system_endianness() == 617 llvm::support::endianness::big) { 618 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 619 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 620 // work correctly in BE format. 621 // ex. `value` (2 words including 10 bytes) 622 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 623 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 624 rawData + (bitPos / CHAR_BIT)); 625 } else { 626 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 627 llvm::divideCeil(bitWidth, CHAR_BIT), 628 rawData + (bitPos / CHAR_BIT)); 629 } 630 } 631 632 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 633 /// `rawData`. 634 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 635 // Handle a boolean bit position. 636 if (bitWidth == 1) 637 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 638 639 // Otherwise, the bit position must be 8-bit aligned. 640 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 641 APInt result(bitWidth, 0); 642 if (llvm::support::endian::system_endianness() == 643 llvm::support::endianness::big) { 644 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 645 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 646 // work correctly in BE format. 647 // ex. `result` (2 words including 10 bytes) 648 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 649 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 650 llvm::divideCeil(bitWidth, CHAR_BIT), result); 651 } else { 652 std::copy_n(rawData + (bitPos / CHAR_BIT), 653 llvm::divideCeil(bitWidth, CHAR_BIT), 654 const_cast<char *>( 655 reinterpret_cast<const char *>(result.getRawData()))); 656 } 657 return result; 658 } 659 660 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 661 /// the same element count as 'type'. 662 template <typename Values> 663 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 664 return (values.size() == 1) || 665 (type.getNumElements() == static_cast<int64_t>(values.size())); 666 } 667 668 //===----------------------------------------------------------------------===// 669 // DenseElementsAttr Iterators 670 //===----------------------------------------------------------------------===// 671 672 //===----------------------------------------------------------------------===// 673 // AttributeElementIterator 674 675 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 676 DenseElementsAttr attr, size_t index) 677 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 678 Attribute, Attribute, Attribute>( 679 attr.getAsOpaquePointer(), index) {} 680 681 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 682 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 683 Type eltTy = owner.getType().getElementType(); 684 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 685 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 686 if (eltTy.isa<IndexType>()) 687 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 688 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 689 IntElementIterator intIt(owner, index); 690 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 691 return FloatAttr::get(eltTy, *floatIt); 692 } 693 if (owner.isa<DenseStringElementsAttr>()) { 694 ArrayRef<StringRef> vals = owner.getRawStringData(); 695 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 696 } 697 llvm_unreachable("unexpected element type"); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // BoolElementIterator 702 703 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 704 DenseElementsAttr attr, size_t dataIndex) 705 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 706 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 707 708 bool DenseElementsAttr::BoolElementIterator::operator*() const { 709 return getBit(getData(), getDataIndex()); 710 } 711 712 //===----------------------------------------------------------------------===// 713 // IntElementIterator 714 715 DenseElementsAttr::IntElementIterator::IntElementIterator( 716 DenseElementsAttr attr, size_t dataIndex) 717 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 718 attr.getRawData().data(), attr.isSplat(), dataIndex), 719 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 720 721 APInt DenseElementsAttr::IntElementIterator::operator*() const { 722 return readBits(getData(), 723 getDataIndex() * getDenseElementStorageWidth(bitWidth), 724 bitWidth); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // ComplexIntElementIterator 729 730 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 731 DenseElementsAttr attr, size_t dataIndex) 732 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 733 std::complex<APInt>, std::complex<APInt>, 734 std::complex<APInt>>( 735 attr.getRawData().data(), attr.isSplat(), dataIndex) { 736 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 737 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 738 } 739 740 std::complex<APInt> 741 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 742 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 743 size_t offset = getDataIndex() * storageWidth * 2; 744 return {readBits(getData(), offset, bitWidth), 745 readBits(getData(), offset + storageWidth, bitWidth)}; 746 } 747 748 //===----------------------------------------------------------------------===// 749 // FloatElementIterator 750 751 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 752 const llvm::fltSemantics &smt, IntElementIterator it) 753 : llvm::mapped_iterator<IntElementIterator, 754 std::function<APFloat(const APInt &)>>( 755 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 756 757 //===----------------------------------------------------------------------===// 758 // ComplexFloatElementIterator 759 760 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 761 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 762 : llvm::mapped_iterator< 763 ComplexIntElementIterator, 764 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 765 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 766 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 767 }) {} 768 769 //===----------------------------------------------------------------------===// 770 // DenseElementsAttr 771 //===----------------------------------------------------------------------===// 772 773 /// Method for support type inquiry through isa, cast and dyn_cast. 774 bool DenseElementsAttr::classof(Attribute attr) { 775 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 776 } 777 778 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 779 ArrayRef<Attribute> values) { 780 assert(hasSameElementsOrSplat(type, values)); 781 782 // If the element type is not based on int/float/index, assume it is a string 783 // type. 784 auto eltType = type.getElementType(); 785 if (!type.getElementType().isIntOrIndexOrFloat()) { 786 SmallVector<StringRef, 8> stringValues; 787 stringValues.reserve(values.size()); 788 for (Attribute attr : values) { 789 assert(attr.isa<StringAttr>() && 790 "expected string value for non integer/index/float element"); 791 stringValues.push_back(attr.cast<StringAttr>().getValue()); 792 } 793 return get(type, stringValues); 794 } 795 796 // Otherwise, get the raw storage width to use for the allocation. 797 size_t bitWidth = getDenseElementBitWidth(eltType); 798 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 799 800 // Compress the attribute values into a character buffer. 801 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 802 values.size()); 803 APInt intVal; 804 for (unsigned i = 0, e = values.size(); i < e; ++i) { 805 assert(eltType == values[i].getType() && 806 "expected attribute value to have element type"); 807 if (eltType.isa<FloatType>()) 808 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 809 else if (eltType.isa<IntegerType>()) 810 intVal = values[i].cast<IntegerAttr>().getValue(); 811 else 812 llvm_unreachable("unexpected element type"); 813 814 assert(intVal.getBitWidth() == bitWidth && 815 "expected value to have same bitwidth as element type"); 816 writeBits(data.data(), i * storageBitWidth, intVal); 817 } 818 return DenseIntOrFPElementsAttr::getRaw(type, data, 819 /*isSplat=*/(values.size() == 1)); 820 } 821 822 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 823 ArrayRef<bool> values) { 824 assert(hasSameElementsOrSplat(type, values)); 825 assert(type.getElementType().isInteger(1)); 826 827 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 828 for (int i = 0, e = values.size(); i != e; ++i) 829 setBit(buff.data(), i, values[i]); 830 return DenseIntOrFPElementsAttr::getRaw(type, buff, 831 /*isSplat=*/(values.size() == 1)); 832 } 833 834 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 835 ArrayRef<StringRef> values) { 836 assert(!type.getElementType().isIntOrFloat()); 837 return DenseStringElementsAttr::get(type, values); 838 } 839 840 /// Constructs a dense integer elements attribute from an array of APInt 841 /// values. Each APInt value is expected to have the same bitwidth as the 842 /// element type of 'type'. 843 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 844 ArrayRef<APInt> values) { 845 assert(type.getElementType().isIntOrIndex()); 846 assert(hasSameElementsOrSplat(type, values)); 847 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 848 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 849 /*isSplat=*/(values.size() == 1)); 850 } 851 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 852 ArrayRef<std::complex<APInt>> values) { 853 ComplexType complex = type.getElementType().cast<ComplexType>(); 854 assert(complex.getElementType().isa<IntegerType>()); 855 assert(hasSameElementsOrSplat(type, values)); 856 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 857 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 858 values.size() * 2); 859 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 860 /*isSplat=*/(values.size() == 1)); 861 } 862 863 // Constructs a dense float elements attribute from an array of APFloat 864 // values. Each APFloat value is expected to have the same bitwidth as the 865 // element type of 'type'. 866 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 867 ArrayRef<APFloat> values) { 868 assert(type.getElementType().isa<FloatType>()); 869 assert(hasSameElementsOrSplat(type, values)); 870 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 871 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 872 /*isSplat=*/(values.size() == 1)); 873 } 874 DenseElementsAttr 875 DenseElementsAttr::get(ShapedType type, 876 ArrayRef<std::complex<APFloat>> values) { 877 ComplexType complex = type.getElementType().cast<ComplexType>(); 878 assert(complex.getElementType().isa<FloatType>()); 879 assert(hasSameElementsOrSplat(type, values)); 880 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 881 values.size() * 2); 882 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 883 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 884 /*isSplat=*/(values.size() == 1)); 885 } 886 887 /// Construct a dense elements attribute from a raw buffer representing the 888 /// data for this attribute. Users should generally not use this methods as 889 /// the expected buffer format may not be a form the user expects. 890 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 891 ArrayRef<char> rawBuffer, 892 bool isSplatBuffer) { 893 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 894 } 895 896 /// Returns true if the given buffer is a valid raw buffer for the given type. 897 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 898 ArrayRef<char> rawBuffer, 899 bool &detectedSplat) { 900 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 901 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 902 903 // Storage width of 1 is special as it is packed by the bit. 904 if (storageWidth == 1) { 905 // Check for a splat, or a buffer equal to the number of elements. 906 if ((detectedSplat = rawBuffer.size() == 1)) 907 return true; 908 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 909 } 910 // All other types are 8-bit aligned. 911 if ((detectedSplat = rawBufferWidth == storageWidth)) 912 return true; 913 return rawBufferWidth == (storageWidth * type.getNumElements()); 914 } 915 916 /// Check the information for a C++ data type, check if this type is valid for 917 /// the current attribute. This method is used to verify specific type 918 /// invariants that the templatized 'getValues' method cannot. 919 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 920 bool isSigned) { 921 // Make sure that the data element size is the same as the type element width. 922 if (getDenseElementBitWidth(type) != 923 static_cast<size_t>(dataEltSize * CHAR_BIT)) 924 return false; 925 926 // Check that the element type is either float or integer or index. 927 if (!isInt) 928 return type.isa<FloatType>(); 929 if (type.isIndex()) 930 return true; 931 932 auto intType = type.dyn_cast<IntegerType>(); 933 if (!intType) 934 return false; 935 936 // Make sure signedness semantics is consistent. 937 if (intType.isSignless()) 938 return true; 939 return intType.isSigned() ? isSigned : !isSigned; 940 } 941 942 /// Defaults down the subclass implementation. 943 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 944 ArrayRef<char> data, 945 int64_t dataEltSize, 946 bool isInt, bool isSigned) { 947 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 948 isSigned); 949 } 950 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 951 ArrayRef<char> data, 952 int64_t dataEltSize, 953 bool isInt, 954 bool isSigned) { 955 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 956 isInt, isSigned); 957 } 958 959 /// A method used to verify specific type invariants that the templatized 'get' 960 /// method cannot. 961 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 962 bool isSigned) const { 963 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 964 isSigned); 965 } 966 967 /// Check the information for a C++ data type, check if this type is valid for 968 /// the current attribute. 969 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 970 bool isSigned) const { 971 return ::isValidIntOrFloat( 972 getType().getElementType().cast<ComplexType>().getElementType(), 973 dataEltSize / 2, isInt, isSigned); 974 } 975 976 /// Returns true if this attribute corresponds to a splat, i.e. if all element 977 /// values are the same. 978 bool DenseElementsAttr::isSplat() const { 979 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 980 } 981 982 /// Return the held element values as a range of Attributes. 983 auto DenseElementsAttr::getAttributeValues() const 984 -> llvm::iterator_range<AttributeElementIterator> { 985 return {attr_value_begin(), attr_value_end()}; 986 } 987 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 988 return AttributeElementIterator(*this, 0); 989 } 990 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 991 return AttributeElementIterator(*this, getNumElements()); 992 } 993 994 /// Return the held element values as a range of bool. The element type of 995 /// this attribute must be of integer type of bitwidth 1. 996 auto DenseElementsAttr::getBoolValues() const 997 -> llvm::iterator_range<BoolElementIterator> { 998 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 999 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 1000 (void)eltType; 1001 return {BoolElementIterator(*this, 0), 1002 BoolElementIterator(*this, getNumElements())}; 1003 } 1004 1005 /// Return the held element values as a range of APInts. The element type of 1006 /// this attribute must be of integer type. 1007 auto DenseElementsAttr::getIntValues() const 1008 -> llvm::iterator_range<IntElementIterator> { 1009 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1010 return {raw_int_begin(), raw_int_end()}; 1011 } 1012 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 1013 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1014 return raw_int_begin(); 1015 } 1016 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 1017 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1018 return raw_int_end(); 1019 } 1020 auto DenseElementsAttr::getComplexIntValues() const 1021 -> llvm::iterator_range<ComplexIntElementIterator> { 1022 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1023 (void)eltTy; 1024 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 1025 return {ComplexIntElementIterator(*this, 0), 1026 ComplexIntElementIterator(*this, getNumElements())}; 1027 } 1028 1029 /// Return the held element values as a range of APFloat. The element type of 1030 /// this attribute must be of float type. 1031 auto DenseElementsAttr::getFloatValues() const 1032 -> llvm::iterator_range<FloatElementIterator> { 1033 auto elementType = getType().getElementType().cast<FloatType>(); 1034 const auto &elementSemantics = elementType.getFloatSemantics(); 1035 return {FloatElementIterator(elementSemantics, raw_int_begin()), 1036 FloatElementIterator(elementSemantics, raw_int_end())}; 1037 } 1038 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1039 return getFloatValues().begin(); 1040 } 1041 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1042 return getFloatValues().end(); 1043 } 1044 auto DenseElementsAttr::getComplexFloatValues() const 1045 -> llvm::iterator_range<ComplexFloatElementIterator> { 1046 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1047 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1048 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1049 return {{semantics, {*this, 0}}, 1050 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1051 } 1052 1053 /// Return the raw storage data held by this attribute. 1054 ArrayRef<char> DenseElementsAttr::getRawData() const { 1055 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1056 } 1057 1058 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1059 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1060 } 1061 1062 /// Return a new DenseElementsAttr that has the same data as the current 1063 /// attribute, but has been reshaped to 'newType'. The new type must have the 1064 /// same total number of elements as well as element type. 1065 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1066 ShapedType curType = getType(); 1067 if (curType == newType) 1068 return *this; 1069 1070 (void)curType; 1071 assert(newType.getElementType() == curType.getElementType() && 1072 "expected the same element type"); 1073 assert(newType.getNumElements() == curType.getNumElements() && 1074 "expected the same number of elements"); 1075 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1076 } 1077 1078 DenseElementsAttr 1079 DenseElementsAttr::mapValues(Type newElementType, 1080 function_ref<APInt(const APInt &)> mapping) const { 1081 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1082 } 1083 1084 DenseElementsAttr DenseElementsAttr::mapValues( 1085 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1086 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1087 } 1088 1089 //===----------------------------------------------------------------------===// 1090 // DenseStringElementsAttr 1091 //===----------------------------------------------------------------------===// 1092 1093 DenseStringElementsAttr 1094 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1095 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1096 } 1097 1098 //===----------------------------------------------------------------------===// 1099 // DenseIntOrFPElementsAttr 1100 //===----------------------------------------------------------------------===// 1101 1102 /// Utility method to write a range of APInt values to a buffer. 1103 template <typename APRangeT> 1104 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1105 APRangeT &&values) { 1106 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1107 size_t offset = 0; 1108 for (auto it = values.begin(), e = values.end(); it != e; 1109 ++it, offset += storageWidth) { 1110 assert((*it).getBitWidth() <= storageWidth); 1111 writeBits(data.data(), offset, *it); 1112 } 1113 } 1114 1115 /// Constructs a dense elements attribute from an array of raw APFloat values. 1116 /// Each APFloat value is expected to have the same bitwidth as the element 1117 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1118 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1119 size_t storageWidth, 1120 ArrayRef<APFloat> values, 1121 bool isSplat) { 1122 std::vector<char> data; 1123 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1124 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1125 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1126 } 1127 1128 /// Constructs a dense elements attribute from an array of raw APInt values. 1129 /// Each APInt value is expected to have the same bitwidth as the element type 1130 /// of 'type'. 1131 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1132 size_t storageWidth, 1133 ArrayRef<APInt> values, 1134 bool isSplat) { 1135 std::vector<char> data; 1136 writeAPIntsToBuffer(storageWidth, data, values); 1137 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1138 } 1139 1140 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1141 ArrayRef<char> data, 1142 bool isSplat) { 1143 assert((type.isa<RankedTensorType, VectorType>()) && 1144 "type must be ranked tensor or vector"); 1145 assert(type.hasStaticShape() && "type must have static shape"); 1146 return Base::get(type.getContext(), type, data, isSplat); 1147 } 1148 1149 /// Overload of the raw 'get' method that asserts that the given type is of 1150 /// complex type. This method is used to verify type invariants that the 1151 /// templatized 'get' method cannot. 1152 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1153 ArrayRef<char> data, 1154 int64_t dataEltSize, 1155 bool isInt, 1156 bool isSigned) { 1157 assert(::isValidIntOrFloat( 1158 type.getElementType().cast<ComplexType>().getElementType(), 1159 dataEltSize / 2, isInt, isSigned)); 1160 1161 int64_t numElements = data.size() / dataEltSize; 1162 assert(numElements == 1 || numElements == type.getNumElements()); 1163 return getRaw(type, data, /*isSplat=*/numElements == 1); 1164 } 1165 1166 /// Overload of the 'getRaw' method that asserts that the given type is of 1167 /// integer type. This method is used to verify type invariants that the 1168 /// templatized 'get' method cannot. 1169 DenseElementsAttr 1170 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1171 int64_t dataEltSize, bool isInt, 1172 bool isSigned) { 1173 assert( 1174 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1175 1176 int64_t numElements = data.size() / dataEltSize; 1177 assert(numElements == 1 || numElements == type.getNumElements()); 1178 return getRaw(type, data, /*isSplat=*/numElements == 1); 1179 } 1180 1181 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1182 const char *inRawData, char *outRawData, size_t elementBitWidth, 1183 size_t numElements) { 1184 using llvm::support::ulittle16_t; 1185 using llvm::support::ulittle32_t; 1186 using llvm::support::ulittle64_t; 1187 1188 assert(llvm::support::endian::system_endianness() == // NOLINT 1189 llvm::support::endianness::big); // NOLINT 1190 // NOLINT to avoid warning message about replacing by static_assert() 1191 1192 // Following std::copy_n always converts endianness on BE machine. 1193 switch (elementBitWidth) { 1194 case 16: { 1195 const ulittle16_t *inRawDataPos = 1196 reinterpret_cast<const ulittle16_t *>(inRawData); 1197 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1198 std::copy_n(inRawDataPos, numElements, outDataPos); 1199 break; 1200 } 1201 case 32: { 1202 const ulittle32_t *inRawDataPos = 1203 reinterpret_cast<const ulittle32_t *>(inRawData); 1204 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1205 std::copy_n(inRawDataPos, numElements, outDataPos); 1206 break; 1207 } 1208 case 64: { 1209 const ulittle64_t *inRawDataPos = 1210 reinterpret_cast<const ulittle64_t *>(inRawData); 1211 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1212 std::copy_n(inRawDataPos, numElements, outDataPos); 1213 break; 1214 } 1215 default: { 1216 size_t nBytes = elementBitWidth / CHAR_BIT; 1217 for (size_t i = 0; i < nBytes; i++) 1218 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1219 break; 1220 } 1221 } 1222 } 1223 1224 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1225 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1226 ShapedType type) { 1227 size_t numElements = type.getNumElements(); 1228 Type elementType = type.getElementType(); 1229 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1230 elementType = complexTy.getElementType(); 1231 numElements = numElements * 2; 1232 } 1233 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1234 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1235 inRawData.size() <= outRawData.size()); 1236 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1237 elementBitWidth, numElements); 1238 } 1239 1240 //===----------------------------------------------------------------------===// 1241 // DenseFPElementsAttr 1242 //===----------------------------------------------------------------------===// 1243 1244 template <typename Fn, typename Attr> 1245 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1246 Type newElementType, 1247 llvm::SmallVectorImpl<char> &data) { 1248 size_t bitWidth = getDenseElementBitWidth(newElementType); 1249 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1250 1251 ShapedType newArrayType; 1252 if (inType.isa<RankedTensorType>()) 1253 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1254 else if (inType.isa<UnrankedTensorType>()) 1255 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1256 else if (inType.isa<VectorType>()) 1257 newArrayType = VectorType::get(inType.getShape(), newElementType); 1258 else 1259 assert(newArrayType && "Unhandled tensor type"); 1260 1261 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1262 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1263 1264 // Functor used to process a single element value of the attribute. 1265 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1266 auto newInt = mapping(value); 1267 assert(newInt.getBitWidth() == bitWidth); 1268 writeBits(data.data(), index * storageBitWidth, newInt); 1269 }; 1270 1271 // Check for the splat case. 1272 if (attr.isSplat()) { 1273 processElt(*attr.begin(), /*index=*/0); 1274 return newArrayType; 1275 } 1276 1277 // Otherwise, process all of the element values. 1278 uint64_t elementIdx = 0; 1279 for (auto value : attr) 1280 processElt(value, elementIdx++); 1281 return newArrayType; 1282 } 1283 1284 DenseElementsAttr DenseFPElementsAttr::mapValues( 1285 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1286 llvm::SmallVector<char, 8> elementData; 1287 auto newArrayType = 1288 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1289 1290 return getRaw(newArrayType, elementData, isSplat()); 1291 } 1292 1293 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1294 bool DenseFPElementsAttr::classof(Attribute attr) { 1295 return attr.isa<DenseElementsAttr>() && 1296 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // DenseIntElementsAttr 1301 //===----------------------------------------------------------------------===// 1302 1303 DenseElementsAttr DenseIntElementsAttr::mapValues( 1304 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1305 llvm::SmallVector<char, 8> elementData; 1306 auto newArrayType = 1307 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1308 1309 return getRaw(newArrayType, elementData, isSplat()); 1310 } 1311 1312 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1313 bool DenseIntElementsAttr::classof(Attribute attr) { 1314 return attr.isa<DenseElementsAttr>() && 1315 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1316 } 1317 1318 //===----------------------------------------------------------------------===// 1319 // OpaqueElementsAttr 1320 //===----------------------------------------------------------------------===// 1321 1322 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1323 StringRef bytes) { 1324 assert(TensorType::isValidElementType(type.getElementType()) && 1325 "Input element type should be a valid tensor element type"); 1326 return Base::get(type.getContext(), type, dialect, bytes); 1327 } 1328 1329 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1330 1331 /// Return the value at the given index. If index does not refer to a valid 1332 /// element, then a null attribute is returned. 1333 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1334 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1335 return Attribute(); 1336 } 1337 1338 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1339 1340 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1341 auto *d = getDialect(); 1342 if (!d) 1343 return true; 1344 auto *interface = 1345 d->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1346 if (!interface) 1347 return true; 1348 return failed(interface->decode(*this, result)); 1349 } 1350 1351 //===----------------------------------------------------------------------===// 1352 // SparseElementsAttr 1353 //===----------------------------------------------------------------------===// 1354 1355 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1356 DenseElementsAttr indices, 1357 DenseElementsAttr values) { 1358 assert(indices.getType().getElementType().isInteger(64) && 1359 "expected sparse indices to be 64-bit integer values"); 1360 assert((type.isa<RankedTensorType, VectorType>()) && 1361 "type must be ranked tensor or vector"); 1362 assert(type.hasStaticShape() && "type must have static shape"); 1363 return Base::get(type.getContext(), type, 1364 indices.cast<DenseIntElementsAttr>(), values); 1365 } 1366 1367 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1368 return getImpl()->indices; 1369 } 1370 1371 DenseElementsAttr SparseElementsAttr::getValues() const { 1372 return getImpl()->values; 1373 } 1374 1375 /// Return the value of the element at the given index. 1376 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1377 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1378 auto type = getType(); 1379 1380 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1381 // as a 1-D index array. 1382 auto sparseIndices = getIndices(); 1383 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1384 1385 // Check to see if the indices are a splat. 1386 if (sparseIndices.isSplat()) { 1387 // If the index is also not a splat of the index value, we know that the 1388 // value is zero. 1389 auto splatIndex = *sparseIndexValues.begin(); 1390 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1391 return getZeroAttr(); 1392 1393 // If the indices are a splat, we also expect the values to be a splat. 1394 assert(getValues().isSplat() && "expected splat values"); 1395 return getValues().getSplatValue(); 1396 } 1397 1398 // Build a mapping between known indices and the offset of the stored element. 1399 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1400 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1401 size_t rank = type.getRank(); 1402 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1403 mappedIndices.try_emplace( 1404 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1405 1406 // Look for the provided index key within the mapped indices. If the provided 1407 // index is not found, then return a zero attribute. 1408 auto it = mappedIndices.find(index); 1409 if (it == mappedIndices.end()) 1410 return getZeroAttr(); 1411 1412 // Otherwise, return the held sparse value element. 1413 return getValues().getValue(it->second); 1414 } 1415 1416 /// Get a zero APFloat for the given sparse attribute. 1417 APFloat SparseElementsAttr::getZeroAPFloat() const { 1418 auto eltType = getType().getElementType().cast<FloatType>(); 1419 return APFloat(eltType.getFloatSemantics()); 1420 } 1421 1422 /// Get a zero APInt for the given sparse attribute. 1423 APInt SparseElementsAttr::getZeroAPInt() const { 1424 auto eltType = getType().getElementType().cast<IntegerType>(); 1425 return APInt::getNullValue(eltType.getWidth()); 1426 } 1427 1428 /// Get a zero attribute for the given attribute type. 1429 Attribute SparseElementsAttr::getZeroAttr() const { 1430 auto eltType = getType().getElementType(); 1431 1432 // Handle floating point elements. 1433 if (eltType.isa<FloatType>()) 1434 return FloatAttr::get(eltType, 0); 1435 1436 // Otherwise, this is an integer. 1437 // TODO: Handle StringAttr here. 1438 return IntegerAttr::get(eltType, 0); 1439 } 1440 1441 /// Flatten, and return, all of the sparse indices in this attribute in 1442 /// row-major order. 1443 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1444 std::vector<ptrdiff_t> flatSparseIndices; 1445 1446 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1447 // as a 1-D index array. 1448 auto sparseIndices = getIndices(); 1449 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1450 if (sparseIndices.isSplat()) { 1451 SmallVector<uint64_t, 8> indices(getType().getRank(), 1452 *sparseIndexValues.begin()); 1453 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1454 return flatSparseIndices; 1455 } 1456 1457 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1458 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1459 size_t rank = getType().getRank(); 1460 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1461 flatSparseIndices.push_back(getFlattenedIndex( 1462 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1463 return flatSparseIndices; 1464 } 1465