1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IntegerSet.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AffineMapAttr 26 //===----------------------------------------------------------------------===// 27 28 AffineMapAttr AffineMapAttr::get(AffineMap value) { 29 return Base::get(value.getContext(), value); 30 } 31 32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 33 34 //===----------------------------------------------------------------------===// 35 // ArrayAttr 36 //===----------------------------------------------------------------------===// 37 38 ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) { 39 return Base::get(context, value); 40 } 41 42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 43 44 Attribute ArrayAttr::operator[](unsigned idx) const { 45 assert(idx < size() && "index out of bounds"); 46 return getValue()[idx]; 47 } 48 49 //===----------------------------------------------------------------------===// 50 // DictionaryAttr 51 //===----------------------------------------------------------------------===// 52 53 /// Helper function that does either an 
in place sort or sorts from source array 54 /// into destination. If inPlace then storage is both the source and the 55 /// destination, else value is the source and storage destination. Returns 56 /// whether source was sorted. 57 template <bool inPlace> 58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 59 SmallVectorImpl<NamedAttribute> &storage) { 60 // Specialize for the common case. 61 switch (value.size()) { 62 case 0: 63 // Zero already sorted. 64 break; 65 case 1: 66 // One already sorted but may need to be copied. 67 if (!inPlace) 68 storage.assign({value[0]}); 69 break; 70 case 2: { 71 bool isSorted = value[0] < value[1]; 72 if (inPlace) { 73 if (!isSorted) 74 std::swap(storage[0], storage[1]); 75 } else if (isSorted) { 76 storage.assign({value[0], value[1]}); 77 } else { 78 storage.assign({value[1], value[0]}); 79 } 80 return !isSorted; 81 } 82 default: 83 if (!inPlace) 84 storage.assign(value.begin(), value.end()); 85 // Check to see they are sorted already. 86 bool isSorted = llvm::is_sorted(value); 87 if (!isSorted) { 88 // If not, do a general sort. 89 llvm::array_pod_sort(storage.begin(), storage.end()); 90 value = storage; 91 } 92 return !isSorted; 93 } 94 return false; 95 } 96 97 /// Returns an entry with a duplicate name from the given sorted array of named 98 /// attributes. Returns llvm::None if all elements have unique names. 99 static Optional<NamedAttribute> 100 findDuplicateElement(ArrayRef<NamedAttribute> value) { 101 const Optional<NamedAttribute> none{llvm::None}; 102 if (value.size() < 2) 103 return none; 104 105 if (value.size() == 2) 106 return value[0].first == value[1].first ? value[0] : none; 107 108 auto it = std::adjacent_find( 109 value.begin(), value.end(), 110 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 111 return it != value.end() ? 
*it : none; 112 } 113 114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 115 SmallVectorImpl<NamedAttribute> &storage) { 116 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 117 assert(!findDuplicateElement(storage) && 118 "DictionaryAttr element names must be unique"); 119 return isSorted; 120 } 121 122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 123 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 124 assert(!findDuplicateElement(array) && 125 "DictionaryAttr element names must be unique"); 126 return isSorted; 127 } 128 129 Optional<NamedAttribute> 130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 131 bool isSorted) { 132 if (!isSorted) 133 dictionaryAttrSort</*inPlace=*/true>(array, array); 134 return findDuplicateElement(array); 135 } 136 137 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 138 ArrayRef<NamedAttribute> value) { 139 if (value.empty()) 140 return DictionaryAttr::getEmpty(context); 141 assert(llvm::all_of(value, 142 [](const NamedAttribute &attr) { return attr.second; }) && 143 "value cannot have null entries"); 144 145 // We need to sort the element list to canonicalize it. 146 SmallVector<NamedAttribute, 8> storage; 147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 148 value = storage; 149 assert(!findDuplicateElement(value) && 150 "DictionaryAttr element names must be unique"); 151 return Base::get(context, value); 152 } 153 /// Construct a dictionary with an array of values that is known to already be 154 /// sorted by name and uniqued. 155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 156 MLIRContext *context) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 // Ensure that the attribute elements are unique and sorted. 
160 assert(llvm::is_sorted(value, 161 [](NamedAttribute l, NamedAttribute r) { 162 return l.first.strref() < r.first.strref(); 163 }) && 164 "expected attribute values to be sorted"); 165 assert(!findDuplicateElement(value) && 166 "DictionaryAttr element names must be unique"); 167 return Base::get(context, value); 168 } 169 170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 171 return getImpl()->getElements(); 172 } 173 174 /// Return the specified attribute if present, null otherwise. 175 Attribute DictionaryAttr::get(StringRef name) const { 176 Optional<NamedAttribute> attr = getNamed(name); 177 return attr ? attr->second : nullptr; 178 } 179 Attribute DictionaryAttr::get(Identifier name) const { 180 Optional<NamedAttribute> attr = getNamed(name); 181 return attr ? attr->second : nullptr; 182 } 183 184 /// Return the specified named attribute if present, None otherwise. 185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 186 ArrayRef<NamedAttribute> values = getValue(); 187 const auto *it = llvm::lower_bound(values, name); 188 return it != values.end() && it->first == name ? 
*it 189 : Optional<NamedAttribute>(); 190 } 191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 192 for (auto elt : getValue()) 193 if (elt.first == name) 194 return elt; 195 return llvm::None; 196 } 197 198 DictionaryAttr::iterator DictionaryAttr::begin() const { 199 return getValue().begin(); 200 } 201 DictionaryAttr::iterator DictionaryAttr::end() const { 202 return getValue().end(); 203 } 204 size_t DictionaryAttr::size() const { return getValue().size(); } 205 206 //===----------------------------------------------------------------------===// 207 // FloatAttr 208 //===----------------------------------------------------------------------===// 209 210 FloatAttr FloatAttr::get(Type type, double value) { 211 return Base::get(type.getContext(), type, value); 212 } 213 214 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 215 Type type, double value) { 216 return Base::getChecked(emitError, type.getContext(), type, value); 217 } 218 219 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 220 return Base::get(type.getContext(), type, value); 221 } 222 223 FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 224 Type type, const APFloat &value) { 225 return Base::getChecked(emitError, type.getContext(), type, value); 226 } 227 228 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 229 230 double FloatAttr::getValueAsDouble() const { 231 return getValueAsDouble(getValue()); 232 } 233 double FloatAttr::getValueAsDouble(APFloat value) { 234 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 235 bool losesInfo = false; 236 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 237 &losesInfo); 238 } 239 return value.convertToDouble(); 240 } 241 242 /// Verify construction invariants. 
243 static LogicalResult 244 verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError, 245 Type type) { 246 if (!type.isa<FloatType>()) 247 return emitError() << "expected floating point type"; 248 return success(); 249 } 250 251 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 252 Type type, double value) { 253 return verifyFloatTypeInvariants(emitError, type); 254 } 255 256 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 257 Type type, const APFloat &value) { 258 // Verify that the type is correct. 259 if (failed(verifyFloatTypeInvariants(emitError, type))) 260 return failure(); 261 262 // Verify that the type semantics match that of the value. 263 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 264 return emitError() 265 << "FloatAttr type doesn't match the type implied by its value"; 266 } 267 return success(); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // SymbolRefAttr 272 //===----------------------------------------------------------------------===// 273 274 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { 275 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 276 } 277 278 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, 279 ArrayRef<FlatSymbolRefAttr> nestedReferences) { 280 return Base::get(ctx, value, nestedReferences); 281 } 282 283 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 284 285 StringRef SymbolRefAttr::getLeafReference() const { 286 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 287 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getValue(); 288 } 289 290 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 291 return getImpl()->getNestedRefs(); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // IntegerAttr 296 //===----------------------------------------------------------------------===// 297 298 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 299 if (type.isSignlessInteger(1)) 300 return BoolAttr::get(type.getContext(), value.getBoolValue()); 301 return Base::get(type.getContext(), type, value); 302 } 303 304 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 305 // This uses 64 bit APInts by default for index type. 306 if (type.isIndex()) 307 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 308 309 auto intType = type.cast<IntegerType>(); 310 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 311 } 312 313 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 314 315 int64_t IntegerAttr::getInt() const { 316 assert((getImpl()->getType().isIndex() || 317 getImpl()->getType().isSignlessInteger()) && 318 "must be signless integer"); 319 return getValue().getSExtValue(); 320 } 321 322 int64_t IntegerAttr::getSInt() const { 323 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 324 return getValue().getSExtValue(); 325 } 326 327 uint64_t IntegerAttr::getUInt() const { 328 assert(getImpl()->getType().isUnsignedInteger() && 329 "must be unsigned integer"); 330 return getValue().getZExtValue(); 331 } 332 333 static LogicalResult 334 verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError, 335 Type type) { 336 if (type.isa<IntegerType, IndexType>()) 337 return success(); 338 return emitError() << "expected integer or index type"; 339 } 340 341 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 342 Type type, int64_t value) { 343 return 
verifyIntegerTypeInvariants(emitError, type); 344 } 345 346 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, 347 Type type, const APInt &value) { 348 if (failed(verifyIntegerTypeInvariants(emitError, type))) 349 return failure(); 350 if (auto integerType = type.dyn_cast<IntegerType>()) 351 if (integerType.getWidth() != value.getBitWidth()) 352 return emitError() << "integer type bit width (" << integerType.getWidth() 353 << ") doesn't match value bit width (" 354 << value.getBitWidth() << ")"; 355 return success(); 356 } 357 358 //===----------------------------------------------------------------------===// 359 // BoolAttr 360 361 bool BoolAttr::getValue() const { 362 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 363 return storage->getValue().getBoolValue(); 364 } 365 366 bool BoolAttr::classof(Attribute attr) { 367 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 368 return intAttr && intAttr.getType().isSignlessInteger(1); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // IntegerSetAttr 373 //===----------------------------------------------------------------------===// 374 375 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 376 return Base::get(value.getConstraint(0).getContext(), value); 377 } 378 379 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 380 381 //===----------------------------------------------------------------------===// 382 // OpaqueAttr 383 //===----------------------------------------------------------------------===// 384 385 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) { 386 return Base::get(dialect.getContext(), dialect, attrData, type); 387 } 388 389 OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 390 Identifier dialect, StringRef attrData, 391 Type type) { 392 return Base::getChecked(emitError, dialect.getContext(), dialect, attrData, 
393 type); 394 } 395 396 /// Returns the dialect namespace of the opaque attribute. 397 Identifier OpaqueAttr::getDialectNamespace() const { 398 return getImpl()->dialectNamespace; 399 } 400 401 /// Returns the raw attribute data of the opaque attribute. 402 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 403 404 /// Verify the construction of an opaque attribute. 405 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, 406 Identifier dialect, StringRef attrData, 407 Type type) { 408 if (!Dialect::isValidNamespace(dialect.strref())) 409 return emitError() << "invalid dialect namespace '" << dialect << "'"; 410 return success(); 411 } 412 413 //===----------------------------------------------------------------------===// 414 // StringAttr 415 //===----------------------------------------------------------------------===// 416 417 StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) { 418 return get(bytes, NoneType::get(context)); 419 } 420 421 /// Get an instance of a StringAttr with the given string and Type. 
422 StringAttr StringAttr::get(StringRef bytes, Type type) { 423 return Base::get(type.getContext(), bytes, type); 424 } 425 426 StringRef StringAttr::getValue() const { return getImpl()->value; } 427 428 //===----------------------------------------------------------------------===// 429 // TypeAttr 430 //===----------------------------------------------------------------------===// 431 432 TypeAttr TypeAttr::get(Type value) { 433 return Base::get(value.getContext(), value); 434 } 435 436 Type TypeAttr::getValue() const { return getImpl()->value; } 437 438 //===----------------------------------------------------------------------===// 439 // ElementsAttr 440 //===----------------------------------------------------------------------===// 441 442 ShapedType ElementsAttr::getType() const { 443 return Attribute::getType().cast<ShapedType>(); 444 } 445 446 /// Returns the number of elements held by this attribute. 447 int64_t ElementsAttr::getNumElements() const { 448 return getType().getNumElements(); 449 } 450 451 /// Return the value at the given index. If index does not refer to a valid 452 /// element, then a null attribute is returned. 453 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 454 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 455 return denseAttr.getValue(index); 456 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 457 return opaqueAttr.getValue(index); 458 return cast<SparseElementsAttr>().getValue(index); 459 } 460 461 /// Return if the given 'index' refers to a valid element in this attribute. 462 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 463 auto type = getType(); 464 465 // Verify that the rank of the indices matches the held type. 466 auto rank = type.getRank(); 467 if (rank == 0 && index.size() == 1 && index[0] == 0) 468 return true; 469 if (rank != static_cast<int64_t>(index.size())) 470 return false; 471 472 // Verify that all of the indices are within the shape dimensions. 
473 auto shape = type.getShape(); 474 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 475 return static_cast<int64_t>(index[i]) < shape[i]; 476 }); 477 } 478 479 ElementsAttr 480 ElementsAttr::mapValues(Type newElementType, 481 function_ref<APInt(const APInt &)> mapping) const { 482 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 483 return intOrFpAttr.mapValues(newElementType, mapping); 484 llvm_unreachable("unsupported ElementsAttr subtype"); 485 } 486 487 ElementsAttr 488 ElementsAttr::mapValues(Type newElementType, 489 function_ref<APInt(const APFloat &)> mapping) const { 490 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 491 return intOrFpAttr.mapValues(newElementType, mapping); 492 llvm_unreachable("unsupported ElementsAttr subtype"); 493 } 494 495 /// Method for support type inquiry through isa, cast and dyn_cast. 496 bool ElementsAttr::classof(Attribute attr) { 497 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 498 OpaqueElementsAttr, SparseElementsAttr>(); 499 } 500 501 /// Returns the 1 dimensional flattened row-major index from the given 502 /// multi-dimensional index. 503 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 504 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 505 auto type = getType(); 506 507 // Reduce the provided multidimensional index into a flattended 1D row-major 508 // index. 509 auto rank = type.getRank(); 510 auto shape = type.getShape(); 511 uint64_t valueIndex = 0; 512 uint64_t dimMultiplier = 1; 513 for (int i = rank - 1; i >= 0; --i) { 514 valueIndex += index[i] * dimMultiplier; 515 dimMultiplier *= shape[i]; 516 } 517 return valueIndex; 518 } 519 520 //===----------------------------------------------------------------------===// 521 // DenseElementsAttr Utilities 522 //===----------------------------------------------------------------------===// 523 524 /// Get the bitwidth of a dense element type within the buffer. 
525 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 526 static size_t getDenseElementStorageWidth(size_t origWidth) { 527 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 528 } 529 static size_t getDenseElementStorageWidth(Type elementType) { 530 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 531 } 532 533 /// Set a bit to a specific value. 534 static void setBit(char *rawData, size_t bitPos, bool value) { 535 if (value) 536 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 537 else 538 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 539 } 540 541 /// Return the value of the specified bit. 542 static bool getBit(const char *rawData, size_t bitPos) { 543 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 544 } 545 546 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 547 /// BE format. 548 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 549 char *result) { 550 assert(llvm::support::endian::system_endianness() == // NOLINT 551 llvm::support::endianness::big); // NOLINT 552 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 553 554 // Copy the words filled with data. 555 // For example, when `value` has 2 words, the first word is filled with data. 556 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 557 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 558 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 559 numFilledWords, result); 560 // Convert last word of APInt to LE format and store it in char 561 // array(`valueLE`). 562 // ex. 
last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 563 size_t lastWordPos = numFilledWords; 564 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 565 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 566 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 567 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 568 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 569 // and store it in `result`. 570 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 571 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 572 valueLE.begin(), result + lastWordPos, 573 (numBytes - lastWordPos) * CHAR_BIT, 1); 574 } 575 576 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 577 /// format. 578 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 579 APInt &result) { 580 assert(llvm::support::endian::system_endianness() == // NOLINT 581 llvm::support::endianness::big); // NOLINT 582 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 583 584 // Copy the data that fills the word of `result` from `inArray`. 585 // For example, when `result` has 2 words, the first word will be filled with 586 // data. So, the first 8 bytes are copied from `inArray` here. 587 // `inArray` (10 bytes, BE): |abcdefgh|ij| 588 // ==> `result` (2 words, BE): |abcdefgh|--------| 589 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 590 std::copy_n( 591 inArray, numFilledWords, 592 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 593 594 // Convert array data which will be last word of `result` to LE format, and 595 // store it in char array(`inArrayLE`). 596 // ex. 
`inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 597 size_t lastWordPos = numFilledWords; 598 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 599 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 600 inArray + lastWordPos, inArrayLE.begin(), 601 (numBytes - lastWordPos) * CHAR_BIT, 1); 602 603 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 604 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 605 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 606 inArrayLE.begin(), 607 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 608 lastWordPos, 609 APInt::APINT_BITS_PER_WORD, 1); 610 } 611 612 /// Writes value to the bit position `bitPos` in array `rawData`. 613 static void writeBits(char *rawData, size_t bitPos, APInt value) { 614 size_t bitWidth = value.getBitWidth(); 615 616 // If the bitwidth is 1 we just toggle the specific bit. 617 if (bitWidth == 1) 618 return setBit(rawData, bitPos, value.isOneValue()); 619 620 // Otherwise, the bit position is guaranteed to be byte aligned. 621 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 622 if (llvm::support::endian::system_endianness() == 623 llvm::support::endianness::big) { 624 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 625 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 626 // work correctly in BE format. 627 // ex. `value` (2 words including 10 bytes) 628 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 629 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 630 rawData + (bitPos / CHAR_BIT)); 631 } else { 632 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 633 llvm::divideCeil(bitWidth, CHAR_BIT), 634 rawData + (bitPos / CHAR_BIT)); 635 } 636 } 637 638 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 639 /// `rawData`. 
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
  // Handle a boolean bit position: 1-bit values are packed by the bit, so
  // `bitPos` may be any bit offset here.
  if (bitWidth == 1)
    return APInt(1, getBit(rawData, bitPos) ? 1 : 0);

  // Otherwise, the bit position must be 8-bit aligned.
  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
  // Start from an all-zero value and overwrite its raw storage in place.
  APInt result(bitWidth, 0);
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big) {
    // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
    // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
    // work correctly in BE format.
    // ex. `result` (2 words including 10 bytes)
    // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
    copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
                                 llvm::divideCeil(bitWidth, CHAR_BIT), result);
  } else {
    // Little-endian host: the leading bytes of the APInt storage are the low
    // bytes of the value, so a plain byte copy is sufficient.
    std::copy_n(rawData + (bitPos / CHAR_BIT),
                llvm::divideCeil(bitWidth, CHAR_BIT),
                const_cast<char *>(
                    reinterpret_cast<const char *>(result.getRawData())));
  }
  return result;
}

/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
/// the same element count as 'type'.
668 template <typename Values> 669 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 670 return (values.size() == 1) || 671 (type.getNumElements() == static_cast<int64_t>(values.size())); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // DenseElementsAttr Iterators 676 //===----------------------------------------------------------------------===// 677 678 //===----------------------------------------------------------------------===// 679 // AttributeElementIterator 680 681 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 682 DenseElementsAttr attr, size_t index) 683 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 684 Attribute, Attribute, Attribute>( 685 attr.getAsOpaquePointer(), index) {} 686 687 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 688 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 689 Type eltTy = owner.getType().getElementType(); 690 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 691 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 692 if (eltTy.isa<IndexType>()) 693 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 694 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 695 IntElementIterator intIt(owner, index); 696 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 697 return FloatAttr::get(eltTy, *floatIt); 698 } 699 if (owner.isa<DenseStringElementsAttr>()) { 700 ArrayRef<StringRef> vals = owner.getRawStringData(); 701 return StringAttr::get(owner.isSplat() ? 
vals.front() : vals[index], eltTy); 702 } 703 llvm_unreachable("unexpected element type"); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // BoolElementIterator 708 709 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 710 DenseElementsAttr attr, size_t dataIndex) 711 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 712 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 713 714 bool DenseElementsAttr::BoolElementIterator::operator*() const { 715 return getBit(getData(), getDataIndex()); 716 } 717 718 //===----------------------------------------------------------------------===// 719 // IntElementIterator 720 721 DenseElementsAttr::IntElementIterator::IntElementIterator( 722 DenseElementsAttr attr, size_t dataIndex) 723 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 724 attr.getRawData().data(), attr.isSplat(), dataIndex), 725 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 726 727 APInt DenseElementsAttr::IntElementIterator::operator*() const { 728 return readBits(getData(), 729 getDataIndex() * getDenseElementStorageWidth(bitWidth), 730 bitWidth); 731 } 732 733 //===----------------------------------------------------------------------===// 734 // ComplexIntElementIterator 735 736 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 737 DenseElementsAttr attr, size_t dataIndex) 738 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 739 std::complex<APInt>, std::complex<APInt>, 740 std::complex<APInt>>( 741 attr.getRawData().data(), attr.isSplat(), dataIndex) { 742 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 743 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 744 } 745 746 std::complex<APInt> 747 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 748 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 749 size_t offset = 
getDataIndex() * storageWidth * 2; 750 return {readBits(getData(), offset, bitWidth), 751 readBits(getData(), offset + storageWidth, bitWidth)}; 752 } 753 754 //===----------------------------------------------------------------------===// 755 // FloatElementIterator 756 757 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 758 const llvm::fltSemantics &smt, IntElementIterator it) 759 : llvm::mapped_iterator<IntElementIterator, 760 std::function<APFloat(const APInt &)>>( 761 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 762 763 //===----------------------------------------------------------------------===// 764 // ComplexFloatElementIterator 765 766 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 767 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 768 : llvm::mapped_iterator< 769 ComplexIntElementIterator, 770 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 771 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 772 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 773 }) {} 774 775 //===----------------------------------------------------------------------===// 776 // DenseElementsAttr 777 //===----------------------------------------------------------------------===// 778 779 /// Method for support type inquiry through isa, cast and dyn_cast. 780 bool DenseElementsAttr::classof(Attribute attr) { 781 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 782 } 783 784 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 785 ArrayRef<Attribute> values) { 786 assert(hasSameElementsOrSplat(type, values)); 787 788 // If the element type is not based on int/float/index, assume it is a string 789 // type. 
790 auto eltType = type.getElementType(); 791 if (!type.getElementType().isIntOrIndexOrFloat()) { 792 SmallVector<StringRef, 8> stringValues; 793 stringValues.reserve(values.size()); 794 for (Attribute attr : values) { 795 assert(attr.isa<StringAttr>() && 796 "expected string value for non integer/index/float element"); 797 stringValues.push_back(attr.cast<StringAttr>().getValue()); 798 } 799 return get(type, stringValues); 800 } 801 802 // Otherwise, get the raw storage width to use for the allocation. 803 size_t bitWidth = getDenseElementBitWidth(eltType); 804 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 805 806 // Compress the attribute values into a character buffer. 807 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 808 values.size()); 809 APInt intVal; 810 for (unsigned i = 0, e = values.size(); i < e; ++i) { 811 assert(eltType == values[i].getType() && 812 "expected attribute value to have element type"); 813 if (eltType.isa<FloatType>()) 814 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 815 else if (eltType.isa<IntegerType>()) 816 intVal = values[i].cast<IntegerAttr>().getValue(); 817 else 818 llvm_unreachable("unexpected element type"); 819 820 assert(intVal.getBitWidth() == bitWidth && 821 "expected value to have same bitwidth as element type"); 822 writeBits(data.data(), i * storageBitWidth, intVal); 823 } 824 return DenseIntOrFPElementsAttr::getRaw(type, data, 825 /*isSplat=*/(values.size() == 1)); 826 } 827 828 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 829 ArrayRef<bool> values) { 830 assert(hasSameElementsOrSplat(type, values)); 831 assert(type.getElementType().isInteger(1)); 832 833 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 834 for (int i = 0, e = values.size(); i != e; ++i) 835 setBit(buff.data(), i, values[i]); 836 return DenseIntOrFPElementsAttr::getRaw(type, buff, 837 /*isSplat=*/(values.size() == 1)); 838 } 839 840 DenseElementsAttr 
DenseElementsAttr::get(ShapedType type, 841 ArrayRef<StringRef> values) { 842 assert(!type.getElementType().isIntOrFloat()); 843 return DenseStringElementsAttr::get(type, values); 844 } 845 846 /// Constructs a dense integer elements attribute from an array of APInt 847 /// values. Each APInt value is expected to have the same bitwidth as the 848 /// element type of 'type'. 849 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 850 ArrayRef<APInt> values) { 851 assert(type.getElementType().isIntOrIndex()); 852 assert(hasSameElementsOrSplat(type, values)); 853 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 854 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 855 /*isSplat=*/(values.size() == 1)); 856 } 857 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 858 ArrayRef<std::complex<APInt>> values) { 859 ComplexType complex = type.getElementType().cast<ComplexType>(); 860 assert(complex.getElementType().isa<IntegerType>()); 861 assert(hasSameElementsOrSplat(type, values)); 862 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 863 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 864 values.size() * 2); 865 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 866 /*isSplat=*/(values.size() == 1)); 867 } 868 869 // Constructs a dense float elements attribute from an array of APFloat 870 // values. Each APFloat value is expected to have the same bitwidth as the 871 // element type of 'type'. 
872 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 873 ArrayRef<APFloat> values) { 874 assert(type.getElementType().isa<FloatType>()); 875 assert(hasSameElementsOrSplat(type, values)); 876 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 877 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 878 /*isSplat=*/(values.size() == 1)); 879 } 880 DenseElementsAttr 881 DenseElementsAttr::get(ShapedType type, 882 ArrayRef<std::complex<APFloat>> values) { 883 ComplexType complex = type.getElementType().cast<ComplexType>(); 884 assert(complex.getElementType().isa<FloatType>()); 885 assert(hasSameElementsOrSplat(type, values)); 886 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 887 values.size() * 2); 888 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 889 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 890 /*isSplat=*/(values.size() == 1)); 891 } 892 893 /// Construct a dense elements attribute from a raw buffer representing the 894 /// data for this attribute. Users should generally not use this methods as 895 /// the expected buffer format may not be a form the user expects. 896 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 897 ArrayRef<char> rawBuffer, 898 bool isSplatBuffer) { 899 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 900 } 901 902 /// Returns true if the given buffer is a valid raw buffer for the given type. 903 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 904 ArrayRef<char> rawBuffer, 905 bool &detectedSplat) { 906 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 907 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 908 909 // Storage width of 1 is special as it is packed by the bit. 910 if (storageWidth == 1) { 911 // Check for a splat, or a buffer equal to the number of elements. 
912 if ((detectedSplat = rawBuffer.size() == 1)) 913 return true; 914 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 915 } 916 // All other types are 8-bit aligned. 917 if ((detectedSplat = rawBufferWidth == storageWidth)) 918 return true; 919 return rawBufferWidth == (storageWidth * type.getNumElements()); 920 } 921 922 /// Check the information for a C++ data type, check if this type is valid for 923 /// the current attribute. This method is used to verify specific type 924 /// invariants that the templatized 'getValues' method cannot. 925 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 926 bool isSigned) { 927 // Make sure that the data element size is the same as the type element width. 928 if (getDenseElementBitWidth(type) != 929 static_cast<size_t>(dataEltSize * CHAR_BIT)) 930 return false; 931 932 // Check that the element type is either float or integer or index. 933 if (!isInt) 934 return type.isa<FloatType>(); 935 if (type.isIndex()) 936 return true; 937 938 auto intType = type.dyn_cast<IntegerType>(); 939 if (!intType) 940 return false; 941 942 // Make sure signedness semantics is consistent. 943 if (intType.isSignless()) 944 return true; 945 return intType.isSigned() ? isSigned : !isSigned; 946 } 947 948 /// Defaults down the subclass implementation. 949 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 950 ArrayRef<char> data, 951 int64_t dataEltSize, 952 bool isInt, bool isSigned) { 953 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 954 isSigned); 955 } 956 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 957 ArrayRef<char> data, 958 int64_t dataEltSize, 959 bool isInt, 960 bool isSigned) { 961 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 962 isInt, isSigned); 963 } 964 965 /// A method used to verify specific type invariants that the templatized 'get' 966 /// method cannot. 
967 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 968 bool isSigned) const { 969 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 970 isSigned); 971 } 972 973 /// Check the information for a C++ data type, check if this type is valid for 974 /// the current attribute. 975 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 976 bool isSigned) const { 977 return ::isValidIntOrFloat( 978 getType().getElementType().cast<ComplexType>().getElementType(), 979 dataEltSize / 2, isInt, isSigned); 980 } 981 982 /// Returns true if this attribute corresponds to a splat, i.e. if all element 983 /// values are the same. 984 bool DenseElementsAttr::isSplat() const { 985 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 986 } 987 988 /// Return the held element values as a range of Attributes. 989 auto DenseElementsAttr::getAttributeValues() const 990 -> llvm::iterator_range<AttributeElementIterator> { 991 return {attr_value_begin(), attr_value_end()}; 992 } 993 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 994 return AttributeElementIterator(*this, 0); 995 } 996 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 997 return AttributeElementIterator(*this, getNumElements()); 998 } 999 1000 /// Return the held element values as a range of bool. The element type of 1001 /// this attribute must be of integer type of bitwidth 1. 1002 auto DenseElementsAttr::getBoolValues() const 1003 -> llvm::iterator_range<BoolElementIterator> { 1004 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 1005 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 1006 (void)eltType; 1007 return {BoolElementIterator(*this, 0), 1008 BoolElementIterator(*this, getNumElements())}; 1009 } 1010 1011 /// Return the held element values as a range of APInts. The element type of 1012 /// this attribute must be of integer type. 
1013 auto DenseElementsAttr::getIntValues() const 1014 -> llvm::iterator_range<IntElementIterator> { 1015 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1016 return {raw_int_begin(), raw_int_end()}; 1017 } 1018 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 1019 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1020 return raw_int_begin(); 1021 } 1022 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 1023 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1024 return raw_int_end(); 1025 } 1026 auto DenseElementsAttr::getComplexIntValues() const 1027 -> llvm::iterator_range<ComplexIntElementIterator> { 1028 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1029 (void)eltTy; 1030 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 1031 return {ComplexIntElementIterator(*this, 0), 1032 ComplexIntElementIterator(*this, getNumElements())}; 1033 } 1034 1035 /// Return the held element values as a range of APFloat. The element type of 1036 /// this attribute must be of float type. 
1037 auto DenseElementsAttr::getFloatValues() const 1038 -> llvm::iterator_range<FloatElementIterator> { 1039 auto elementType = getType().getElementType().cast<FloatType>(); 1040 const auto &elementSemantics = elementType.getFloatSemantics(); 1041 return {FloatElementIterator(elementSemantics, raw_int_begin()), 1042 FloatElementIterator(elementSemantics, raw_int_end())}; 1043 } 1044 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1045 return getFloatValues().begin(); 1046 } 1047 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1048 return getFloatValues().end(); 1049 } 1050 auto DenseElementsAttr::getComplexFloatValues() const 1051 -> llvm::iterator_range<ComplexFloatElementIterator> { 1052 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1053 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1054 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1055 return {{semantics, {*this, 0}}, 1056 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1057 } 1058 1059 /// Return the raw storage data held by this attribute. 1060 ArrayRef<char> DenseElementsAttr::getRawData() const { 1061 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1062 } 1063 1064 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1065 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1066 } 1067 1068 /// Return a new DenseElementsAttr that has the same data as the current 1069 /// attribute, but has been reshaped to 'newType'. The new type must have the 1070 /// same total number of elements as well as element type. 
1071 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1072 ShapedType curType = getType(); 1073 if (curType == newType) 1074 return *this; 1075 1076 (void)curType; 1077 assert(newType.getElementType() == curType.getElementType() && 1078 "expected the same element type"); 1079 assert(newType.getNumElements() == curType.getNumElements() && 1080 "expected the same number of elements"); 1081 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1082 } 1083 1084 DenseElementsAttr 1085 DenseElementsAttr::mapValues(Type newElementType, 1086 function_ref<APInt(const APInt &)> mapping) const { 1087 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1088 } 1089 1090 DenseElementsAttr DenseElementsAttr::mapValues( 1091 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1092 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1093 } 1094 1095 //===----------------------------------------------------------------------===// 1096 // DenseStringElementsAttr 1097 //===----------------------------------------------------------------------===// 1098 1099 DenseStringElementsAttr 1100 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1101 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1102 } 1103 1104 //===----------------------------------------------------------------------===// 1105 // DenseIntOrFPElementsAttr 1106 //===----------------------------------------------------------------------===// 1107 1108 /// Utility method to write a range of APInt values to a buffer. 
1109 template <typename APRangeT> 1110 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1111 APRangeT &&values) { 1112 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1113 size_t offset = 0; 1114 for (auto it = values.begin(), e = values.end(); it != e; 1115 ++it, offset += storageWidth) { 1116 assert((*it).getBitWidth() <= storageWidth); 1117 writeBits(data.data(), offset, *it); 1118 } 1119 } 1120 1121 /// Constructs a dense elements attribute from an array of raw APFloat values. 1122 /// Each APFloat value is expected to have the same bitwidth as the element 1123 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1124 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1125 size_t storageWidth, 1126 ArrayRef<APFloat> values, 1127 bool isSplat) { 1128 std::vector<char> data; 1129 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1130 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1131 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1132 } 1133 1134 /// Constructs a dense elements attribute from an array of raw APInt values. 1135 /// Each APInt value is expected to have the same bitwidth as the element type 1136 /// of 'type'. 
1137 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1138 size_t storageWidth, 1139 ArrayRef<APInt> values, 1140 bool isSplat) { 1141 std::vector<char> data; 1142 writeAPIntsToBuffer(storageWidth, data, values); 1143 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1144 } 1145 1146 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1147 ArrayRef<char> data, 1148 bool isSplat) { 1149 assert((type.isa<RankedTensorType, VectorType>()) && 1150 "type must be ranked tensor or vector"); 1151 assert(type.hasStaticShape() && "type must have static shape"); 1152 return Base::get(type.getContext(), type, data, isSplat); 1153 } 1154 1155 /// Overload of the raw 'get' method that asserts that the given type is of 1156 /// complex type. This method is used to verify type invariants that the 1157 /// templatized 'get' method cannot. 1158 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1159 ArrayRef<char> data, 1160 int64_t dataEltSize, 1161 bool isInt, 1162 bool isSigned) { 1163 assert(::isValidIntOrFloat( 1164 type.getElementType().cast<ComplexType>().getElementType(), 1165 dataEltSize / 2, isInt, isSigned)); 1166 1167 int64_t numElements = data.size() / dataEltSize; 1168 assert(numElements == 1 || numElements == type.getNumElements()); 1169 return getRaw(type, data, /*isSplat=*/numElements == 1); 1170 } 1171 1172 /// Overload of the 'getRaw' method that asserts that the given type is of 1173 /// integer type. This method is used to verify type invariants that the 1174 /// templatized 'get' method cannot. 
1175 DenseElementsAttr 1176 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1177 int64_t dataEltSize, bool isInt, 1178 bool isSigned) { 1179 assert( 1180 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1181 1182 int64_t numElements = data.size() / dataEltSize; 1183 assert(numElements == 1 || numElements == type.getNumElements()); 1184 return getRaw(type, data, /*isSplat=*/numElements == 1); 1185 } 1186 1187 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1188 const char *inRawData, char *outRawData, size_t elementBitWidth, 1189 size_t numElements) { 1190 using llvm::support::ulittle16_t; 1191 using llvm::support::ulittle32_t; 1192 using llvm::support::ulittle64_t; 1193 1194 assert(llvm::support::endian::system_endianness() == // NOLINT 1195 llvm::support::endianness::big); // NOLINT 1196 // NOLINT to avoid warning message about replacing by static_assert() 1197 1198 // Following std::copy_n always converts endianness on BE machine. 
1199 switch (elementBitWidth) { 1200 case 16: { 1201 const ulittle16_t *inRawDataPos = 1202 reinterpret_cast<const ulittle16_t *>(inRawData); 1203 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1204 std::copy_n(inRawDataPos, numElements, outDataPos); 1205 break; 1206 } 1207 case 32: { 1208 const ulittle32_t *inRawDataPos = 1209 reinterpret_cast<const ulittle32_t *>(inRawData); 1210 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1211 std::copy_n(inRawDataPos, numElements, outDataPos); 1212 break; 1213 } 1214 case 64: { 1215 const ulittle64_t *inRawDataPos = 1216 reinterpret_cast<const ulittle64_t *>(inRawData); 1217 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1218 std::copy_n(inRawDataPos, numElements, outDataPos); 1219 break; 1220 } 1221 default: { 1222 size_t nBytes = elementBitWidth / CHAR_BIT; 1223 for (size_t i = 0; i < nBytes; i++) 1224 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1225 break; 1226 } 1227 } 1228 } 1229 1230 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1231 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1232 ShapedType type) { 1233 size_t numElements = type.getNumElements(); 1234 Type elementType = type.getElementType(); 1235 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1236 elementType = complexTy.getElementType(); 1237 numElements = numElements * 2; 1238 } 1239 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1240 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1241 inRawData.size() <= outRawData.size()); 1242 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1243 elementBitWidth, numElements); 1244 } 1245 1246 //===----------------------------------------------------------------------===// 1247 // DenseFPElementsAttr 1248 //===----------------------------------------------------------------------===// 1249 1250 template <typename Fn, typename Attr> 
1251 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1252 Type newElementType, 1253 llvm::SmallVectorImpl<char> &data) { 1254 size_t bitWidth = getDenseElementBitWidth(newElementType); 1255 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1256 1257 ShapedType newArrayType; 1258 if (inType.isa<RankedTensorType>()) 1259 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1260 else if (inType.isa<UnrankedTensorType>()) 1261 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1262 else if (inType.isa<VectorType>()) 1263 newArrayType = VectorType::get(inType.getShape(), newElementType); 1264 else 1265 assert(newArrayType && "Unhandled tensor type"); 1266 1267 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1268 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1269 1270 // Functor used to process a single element value of the attribute. 1271 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1272 auto newInt = mapping(value); 1273 assert(newInt.getBitWidth() == bitWidth); 1274 writeBits(data.data(), index * storageBitWidth, newInt); 1275 }; 1276 1277 // Check for the splat case. 1278 if (attr.isSplat()) { 1279 processElt(*attr.begin(), /*index=*/0); 1280 return newArrayType; 1281 } 1282 1283 // Otherwise, process all of the element values. 1284 uint64_t elementIdx = 0; 1285 for (auto value : attr) 1286 processElt(value, elementIdx++); 1287 return newArrayType; 1288 } 1289 1290 DenseElementsAttr DenseFPElementsAttr::mapValues( 1291 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1292 llvm::SmallVector<char, 8> elementData; 1293 auto newArrayType = 1294 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1295 1296 return getRaw(newArrayType, elementData, isSplat()); 1297 } 1298 1299 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1300 bool DenseFPElementsAttr::classof(Attribute attr) { 1301 return attr.isa<DenseElementsAttr>() && 1302 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1303 } 1304 1305 //===----------------------------------------------------------------------===// 1306 // DenseIntElementsAttr 1307 //===----------------------------------------------------------------------===// 1308 1309 DenseElementsAttr DenseIntElementsAttr::mapValues( 1310 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1311 llvm::SmallVector<char, 8> elementData; 1312 auto newArrayType = 1313 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1314 1315 return getRaw(newArrayType, elementData, isSplat()); 1316 } 1317 1318 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1319 bool DenseIntElementsAttr::classof(Attribute attr) { 1320 return attr.isa<DenseElementsAttr>() && 1321 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // OpaqueElementsAttr 1326 //===----------------------------------------------------------------------===// 1327 1328 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1329 StringRef bytes) { 1330 assert(TensorType::isValidElementType(type.getElementType()) && 1331 "Input element type should be a valid tensor element type"); 1332 return Base::get(type.getContext(), type, dialect, bytes); 1333 } 1334 1335 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1336 1337 /// Return the value at the given index. If index does not refer to a valid 1338 /// element, then a null attribute is returned. 
1339 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1340 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1341 return Attribute(); 1342 } 1343 1344 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1345 1346 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1347 auto *d = getDialect(); 1348 if (!d) 1349 return true; 1350 auto *interface = 1351 d->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1352 if (!interface) 1353 return true; 1354 return failed(interface->decode(*this, result)); 1355 } 1356 1357 //===----------------------------------------------------------------------===// 1358 // SparseElementsAttr 1359 //===----------------------------------------------------------------------===// 1360 1361 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1362 DenseElementsAttr indices, 1363 DenseElementsAttr values) { 1364 assert(indices.getType().getElementType().isInteger(64) && 1365 "expected sparse indices to be 64-bit integer values"); 1366 assert((type.isa<RankedTensorType, VectorType>()) && 1367 "type must be ranked tensor or vector"); 1368 assert(type.hasStaticShape() && "type must have static shape"); 1369 return Base::get(type.getContext(), type, 1370 indices.cast<DenseIntElementsAttr>(), values); 1371 } 1372 1373 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1374 return getImpl()->indices; 1375 } 1376 1377 DenseElementsAttr SparseElementsAttr::getValues() const { 1378 return getImpl()->values; 1379 } 1380 1381 /// Return the value of the element at the given index. 1382 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1383 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1384 auto type = getType(); 1385 1386 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1387 // as a 1-D index array. 
1388 auto sparseIndices = getIndices(); 1389 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1390 1391 // Check to see if the indices are a splat. 1392 if (sparseIndices.isSplat()) { 1393 // If the index is also not a splat of the index value, we know that the 1394 // value is zero. 1395 auto splatIndex = *sparseIndexValues.begin(); 1396 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1397 return getZeroAttr(); 1398 1399 // If the indices are a splat, we also expect the values to be a splat. 1400 assert(getValues().isSplat() && "expected splat values"); 1401 return getValues().getSplatValue(); 1402 } 1403 1404 // Build a mapping between known indices and the offset of the stored element. 1405 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1406 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1407 size_t rank = type.getRank(); 1408 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1409 mappedIndices.try_emplace( 1410 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1411 1412 // Look for the provided index key within the mapped indices. If the provided 1413 // index is not found, then return a zero attribute. 1414 auto it = mappedIndices.find(index); 1415 if (it == mappedIndices.end()) 1416 return getZeroAttr(); 1417 1418 // Otherwise, return the held sparse value element. 1419 return getValues().getValue(it->second); 1420 } 1421 1422 /// Get a zero APFloat for the given sparse attribute. 1423 APFloat SparseElementsAttr::getZeroAPFloat() const { 1424 auto eltType = getType().getElementType().cast<FloatType>(); 1425 return APFloat(eltType.getFloatSemantics()); 1426 } 1427 1428 /// Get a zero APInt for the given sparse attribute. 1429 APInt SparseElementsAttr::getZeroAPInt() const { 1430 auto eltType = getType().getElementType().cast<IntegerType>(); 1431 return APInt::getNullValue(eltType.getWidth()); 1432 } 1433 1434 /// Get a zero attribute for the given attribute type. 
1435 Attribute SparseElementsAttr::getZeroAttr() const { 1436 auto eltType = getType().getElementType(); 1437 1438 // Handle floating point elements. 1439 if (eltType.isa<FloatType>()) 1440 return FloatAttr::get(eltType, 0); 1441 1442 // Otherwise, this is an integer. 1443 // TODO: Handle StringAttr here. 1444 return IntegerAttr::get(eltType, 0); 1445 } 1446 1447 /// Flatten, and return, all of the sparse indices in this attribute in 1448 /// row-major order. 1449 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1450 std::vector<ptrdiff_t> flatSparseIndices; 1451 1452 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1453 // as a 1-D index array. 1454 auto sparseIndices = getIndices(); 1455 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1456 if (sparseIndices.isSplat()) { 1457 SmallVector<uint64_t, 8> indices(getType().getRank(), 1458 *sparseIndexValues.begin()); 1459 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1460 return flatSparseIndices; 1461 } 1462 1463 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1464 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1465 size_t rank = getType().getRank(); 1466 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1467 flatSparseIndices.push_back(getFlattenedIndex( 1468 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1469 return flatSparseIndices; 1470 } 1471