1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/Twine.h" 20 #include "llvm/Support/Endian.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 //===----------------------------------------------------------------------===// 26 // AttributeStorage 27 //===----------------------------------------------------------------------===// 28 29 AttributeStorage::AttributeStorage(Type type) 30 : type(type.getAsOpaquePointer()) {} 31 AttributeStorage::AttributeStorage() : type(nullptr) {} 32 33 Type AttributeStorage::getType() const { 34 return Type::getFromOpaquePointer(type); 35 } 36 void AttributeStorage::setType(Type newType) { 37 type = newType.getAsOpaquePointer(); 38 } 39 40 //===----------------------------------------------------------------------===// 41 // Attribute 42 //===----------------------------------------------------------------------===// 43 44 /// Return the type of this attribute. 45 Type Attribute::getType() const { return impl->getType(); } 46 47 /// Return the context this attribute belongs to. 48 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 49 50 /// Get the dialect this attribute is registered to. 51 Dialect &Attribute::getDialect() const { 52 return impl->getAbstractAttribute().getDialect(); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // AffineMapAttr 57 //===----------------------------------------------------------------------===// 58 59 AffineMapAttr AffineMapAttr::get(AffineMap value) { 60 return Base::get(value.getContext(), value); 61 } 62 63 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 64 65 //===----------------------------------------------------------------------===// 66 // ArrayAttr 67 //===----------------------------------------------------------------------===// 68 69 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 70 return Base::get(context, value); 71 } 72 73 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 74 75 Attribute ArrayAttr::operator[](unsigned idx) const { 76 assert(idx < size() && "index out of bounds"); 77 return getValue()[idx]; 78 } 79 80 //===----------------------------------------------------------------------===// 81 // DictionaryAttr 82 //===----------------------------------------------------------------------===// 83 84 /// Helper function that does either an in place sort or sorts from source array 85 /// into destination. If inPlace then storage is both the source and the 86 /// destination, else value is the source and storage destination. Returns 87 /// whether source was sorted. 88 template <bool inPlace> 89 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 90 SmallVectorImpl<NamedAttribute> &storage) { 91 // Specialize for the common case. 92 switch (value.size()) { 93 case 0: 94 // Zero already sorted. 95 break; 96 case 1: 97 // One already sorted but may need to be copied. 98 if (!inPlace) 99 storage.assign({value[0]}); 100 break; 101 case 2: { 102 bool isSorted = value[0] < value[1]; 103 if (inPlace) { 104 if (!isSorted) 105 std::swap(storage[0], storage[1]); 106 } else if (isSorted) { 107 storage.assign({value[0], value[1]}); 108 } else { 109 storage.assign({value[1], value[0]}); 110 } 111 return !isSorted; 112 } 113 default: 114 if (!inPlace) 115 storage.assign(value.begin(), value.end()); 116 // Check to see they are sorted already. 117 bool isSorted = llvm::is_sorted(value); 118 if (!isSorted) { 119 // If not, do a general sort. 120 llvm::array_pod_sort(storage.begin(), storage.end()); 121 value = storage; 122 } 123 return !isSorted; 124 } 125 return false; 126 } 127 128 /// Returns an entry with a duplicate name from the given sorted array of named 129 /// attributes. Returns llvm::None if all elements have unique names. 130 static Optional<NamedAttribute> 131 findDuplicateElement(ArrayRef<NamedAttribute> value) { 132 const Optional<NamedAttribute> none{llvm::None}; 133 if (value.size() < 2) 134 return none; 135 136 if (value.size() == 2) 137 return value[0].first == value[1].first ? value[0] : none; 138 139 auto it = std::adjacent_find( 140 value.begin(), value.end(), 141 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); 142 return it != value.end() ? *it : none; 143 } 144 145 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 146 SmallVectorImpl<NamedAttribute> &storage) { 147 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 148 assert(!findDuplicateElement(storage) && 149 "DictionaryAttr element names must be unique"); 150 return isSorted; 151 } 152 153 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 154 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 155 assert(!findDuplicateElement(array) && 156 "DictionaryAttr element names must be unique"); 157 return isSorted; 158 } 159 160 Optional<NamedAttribute> 161 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 162 bool isSorted) { 163 if (!isSorted) 164 dictionaryAttrSort</*inPlace=*/true>(array, array); 165 return findDuplicateElement(array); 166 } 167 168 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 169 MLIRContext *context) { 170 if (value.empty()) 171 return DictionaryAttr::getEmpty(context); 172 assert(llvm::all_of(value, 173 [](const NamedAttribute &attr) { return attr.second; }) && 174 "value cannot have null entries"); 175 176 // We need to sort the element list to canonicalize it. 177 SmallVector<NamedAttribute, 8> storage; 178 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 179 value = storage; 180 assert(!findDuplicateElement(value) && 181 "DictionaryAttr element names must be unique"); 182 return Base::get(context, value); 183 } 184 /// Construct a dictionary with an array of values that is known to already be 185 /// sorted by name and uniqued. 186 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 187 MLIRContext *context) { 188 if (value.empty()) 189 return DictionaryAttr::getEmpty(context); 190 // Ensure that the attribute elements are unique and sorted. 191 assert(llvm::is_sorted(value, 192 [](NamedAttribute l, NamedAttribute r) { 193 return l.first.strref() < r.first.strref(); 194 }) && 195 "expected attribute values to be sorted"); 196 assert(!findDuplicateElement(value) && 197 "DictionaryAttr element names must be unique"); 198 return Base::get(context, value); 199 } 200 201 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 202 return getImpl()->getElements(); 203 } 204 205 /// Return the specified attribute if present, null otherwise. 206 Attribute DictionaryAttr::get(StringRef name) const { 207 Optional<NamedAttribute> attr = getNamed(name); 208 return attr ? attr->second : nullptr; 209 } 210 Attribute DictionaryAttr::get(Identifier name) const { 211 Optional<NamedAttribute> attr = getNamed(name); 212 return attr ? attr->second : nullptr; 213 } 214 215 /// Return the specified named attribute if present, None otherwise. 216 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 217 ArrayRef<NamedAttribute> values = getValue(); 218 const auto *it = llvm::lower_bound(values, name); 219 return it != values.end() && it->first == name ? *it 220 : Optional<NamedAttribute>(); 221 } 222 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 223 for (auto elt : getValue()) 224 if (elt.first == name) 225 return elt; 226 return llvm::None; 227 } 228 229 DictionaryAttr::iterator DictionaryAttr::begin() const { 230 return getValue().begin(); 231 } 232 DictionaryAttr::iterator DictionaryAttr::end() const { 233 return getValue().end(); 234 } 235 size_t DictionaryAttr::size() const { return getValue().size(); } 236 237 //===----------------------------------------------------------------------===// 238 // FloatAttr 239 //===----------------------------------------------------------------------===// 240 241 FloatAttr FloatAttr::get(Type type, double value) { 242 return Base::get(type.getContext(), type, value); 243 } 244 245 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 246 return Base::getChecked(loc, type, value); 247 } 248 249 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 250 return Base::get(type.getContext(), type, value); 251 } 252 253 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 254 return Base::getChecked(loc, type, value); 255 } 256 257 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 258 259 double FloatAttr::getValueAsDouble() const { 260 return getValueAsDouble(getValue()); 261 } 262 double FloatAttr::getValueAsDouble(APFloat value) { 263 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 264 bool losesInfo = false; 265 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 266 &losesInfo); 267 } 268 return value.convertToDouble(); 269 } 270 271 /// Verify construction invariants. 272 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 273 if (!type.isa<FloatType>()) 274 return emitError(loc, "expected floating point type"); 275 return success(); 276 } 277 278 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 279 double value) { 280 return verifyFloatTypeInvariants(loc, type); 281 } 282 283 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 284 const APFloat &value) { 285 // Verify that the type is correct. 286 if (failed(verifyFloatTypeInvariants(loc, type))) 287 return failure(); 288 289 // Verify that the type semantics match that of the value. 290 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 291 return emitError( 292 loc, "FloatAttr type doesn't match the type implied by its value"); 293 } 294 return success(); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // SymbolRefAttr 299 //===----------------------------------------------------------------------===// 300 301 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 302 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 303 } 304 305 SymbolRefAttr SymbolRefAttr::get(StringRef value, 306 ArrayRef<FlatSymbolRefAttr> nestedReferences, 307 MLIRContext *ctx) { 308 return Base::get(ctx, value, nestedReferences); 309 } 310 311 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 312 313 StringRef SymbolRefAttr::getLeafReference() const { 314 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 315 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 316 } 317 318 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 319 return getImpl()->getNestedRefs(); 320 } 321 322 //===----------------------------------------------------------------------===// 323 // IntegerAttr 324 //===----------------------------------------------------------------------===// 325 326 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 327 if (type.isSignlessInteger(1)) 328 return BoolAttr::get(value.getBoolValue(), type.getContext()); 329 return Base::get(type.getContext(), type, value); 330 } 331 332 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 333 // This uses 64 bit APInts by default for index type. 334 if (type.isIndex()) 335 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 336 337 auto intType = type.cast<IntegerType>(); 338 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 339 } 340 341 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 342 343 int64_t IntegerAttr::getInt() const { 344 assert((getImpl()->getType().isIndex() || 345 getImpl()->getType().isSignlessInteger()) && 346 "must be signless integer"); 347 return getValue().getSExtValue(); 348 } 349 350 int64_t IntegerAttr::getSInt() const { 351 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 352 return getValue().getSExtValue(); 353 } 354 355 uint64_t IntegerAttr::getUInt() const { 356 assert(getImpl()->getType().isUnsignedInteger() && 357 "must be unsigned integer"); 358 return getValue().getZExtValue(); 359 } 360 361 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 362 if (type.isa<IntegerType, IndexType>()) 363 return success(); 364 return emitError(loc, "expected integer or index type"); 365 } 366 367 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 368 int64_t value) { 369 return verifyIntegerTypeInvariants(loc, type); 370 } 371 372 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 373 const APInt &value) { 374 if (failed(verifyIntegerTypeInvariants(loc, type))) 375 return failure(); 376 if (auto integerType = type.dyn_cast<IntegerType>()) 377 if (integerType.getWidth() != value.getBitWidth()) 378 return emitError(loc, "integer type bit width (") 379 << integerType.getWidth() << ") doesn't match value bit width (" 380 << value.getBitWidth() << ")"; 381 return success(); 382 } 383 384 //===----------------------------------------------------------------------===// 385 // BoolAttr 386 387 bool BoolAttr::getValue() const { 388 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 389 return storage->getValue().getBoolValue(); 390 } 391 392 bool BoolAttr::classof(Attribute attr) { 393 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 394 return intAttr && intAttr.getType().isSignlessInteger(1); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // IntegerSetAttr 399 //===----------------------------------------------------------------------===// 400 401 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 402 return Base::get(value.getConstraint(0).getContext(), value); 403 } 404 405 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 406 407 //===----------------------------------------------------------------------===// 408 // OpaqueAttr 409 //===----------------------------------------------------------------------===// 410 411 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 412 MLIRContext *context) { 413 return Base::get(context, dialect, attrData, type); 414 } 415 416 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 417 Type type, Location location) { 418 return Base::getChecked(location, dialect, attrData, type); 419 } 420 421 /// Returns the dialect namespace of the opaque attribute. 422 Identifier OpaqueAttr::getDialectNamespace() const { 423 return getImpl()->dialectNamespace; 424 } 425 426 /// Returns the raw attribute data of the opaque attribute. 427 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 428 429 /// Verify the construction of an opaque attribute. 430 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 431 Identifier dialect, 432 StringRef attrData, 433 Type type) { 434 if (!Dialect::isValidNamespace(dialect.strref())) 435 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 436 return success(); 437 } 438 439 //===----------------------------------------------------------------------===// 440 // StringAttr 441 //===----------------------------------------------------------------------===// 442 443 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 444 return get(bytes, NoneType::get(context)); 445 } 446 447 /// Get an instance of a StringAttr with the given string and Type. 448 StringAttr StringAttr::get(StringRef bytes, Type type) { 449 return Base::get(type.getContext(), bytes, type); 450 } 451 452 StringRef StringAttr::getValue() const { return getImpl()->value; } 453 454 //===----------------------------------------------------------------------===// 455 // TypeAttr 456 //===----------------------------------------------------------------------===// 457 458 TypeAttr TypeAttr::get(Type value) { 459 return Base::get(value.getContext(), value); 460 } 461 462 Type TypeAttr::getValue() const { return getImpl()->value; } 463 464 //===----------------------------------------------------------------------===// 465 // ElementsAttr 466 //===----------------------------------------------------------------------===// 467 468 ShapedType ElementsAttr::getType() const { 469 return Attribute::getType().cast<ShapedType>(); 470 } 471 472 /// Returns the number of elements held by this attribute. 473 int64_t ElementsAttr::getNumElements() const { 474 return getType().getNumElements(); 475 } 476 477 /// Return the value at the given index. If index does not refer to a valid 478 /// element, then a null attribute is returned. 479 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 480 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 481 return denseAttr.getValue(index); 482 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 483 return opaqueAttr.getValue(index); 484 return cast<SparseElementsAttr>().getValue(index); 485 } 486 487 /// Return if the given 'index' refers to a valid element in this attribute. 488 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 489 auto type = getType(); 490 491 // Verify that the rank of the indices matches the held type. 492 auto rank = type.getRank(); 493 if (rank != static_cast<int64_t>(index.size())) 494 return false; 495 496 // Verify that all of the indices are within the shape dimensions. 497 auto shape = type.getShape(); 498 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 499 return static_cast<int64_t>(index[i]) < shape[i]; 500 }); 501 } 502 503 ElementsAttr 504 ElementsAttr::mapValues(Type newElementType, 505 function_ref<APInt(const APInt &)> mapping) const { 506 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 507 return intOrFpAttr.mapValues(newElementType, mapping); 508 llvm_unreachable("unsupported ElementsAttr subtype"); 509 } 510 511 ElementsAttr 512 ElementsAttr::mapValues(Type newElementType, 513 function_ref<APInt(const APFloat &)> mapping) const { 514 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 515 return intOrFpAttr.mapValues(newElementType, mapping); 516 llvm_unreachable("unsupported ElementsAttr subtype"); 517 } 518 519 /// Method for support type inquiry through isa, cast and dyn_cast. 520 bool ElementsAttr::classof(Attribute attr) { 521 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 522 OpaqueElementsAttr, SparseElementsAttr>(); 523 } 524 525 /// Returns the 1 dimensional flattened row-major index from the given 526 /// multi-dimensional index. 527 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 528 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 529 auto type = getType(); 530 531 // Reduce the provided multidimensional index into a flattended 1D row-major 532 // index. 533 auto rank = type.getRank(); 534 auto shape = type.getShape(); 535 uint64_t valueIndex = 0; 536 uint64_t dimMultiplier = 1; 537 for (int i = rank - 1; i >= 0; --i) { 538 valueIndex += index[i] * dimMultiplier; 539 dimMultiplier *= shape[i]; 540 } 541 return valueIndex; 542 } 543 544 //===----------------------------------------------------------------------===// 545 // DenseElementsAttr Utilities 546 //===----------------------------------------------------------------------===// 547 548 /// Get the bitwidth of a dense element type within the buffer. 549 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 550 static size_t getDenseElementStorageWidth(size_t origWidth) { 551 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 552 } 553 static size_t getDenseElementStorageWidth(Type elementType) { 554 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 555 } 556 557 /// Set a bit to a specific value. 558 static void setBit(char *rawData, size_t bitPos, bool value) { 559 if (value) 560 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 561 else 562 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 563 } 564 565 /// Return the value of the specified bit. 566 static bool getBit(const char *rawData, size_t bitPos) { 567 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 568 } 569 570 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for 571 /// BE format. 572 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, 573 char *result) { 574 assert(llvm::support::endian::system_endianness() == // NOLINT 575 llvm::support::endianness::big); // NOLINT 576 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 577 578 // Copy the words filled with data. 579 // For example, when `value` has 2 words, the first word is filled with data. 580 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| 581 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 582 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 583 numFilledWords, result); 584 // Convert last word of APInt to LE format and store it in char 585 // array(`valueLE`). 586 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| 587 size_t lastWordPos = numFilledWords; 588 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); 589 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 590 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, 591 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); 592 // Extract actual APInt data from `valueLE`, convert endianness to BE format, 593 // and store it in `result`. 594 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| 595 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 596 valueLE.begin(), result + lastWordPos, 597 (numBytes - lastWordPos) * CHAR_BIT, 1); 598 } 599 600 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE 601 /// format. 602 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, 603 APInt &result) { 604 assert(llvm::support::endian::system_endianness() == // NOLINT 605 llvm::support::endianness::big); // NOLINT 606 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); 607 608 // Copy the data that fills the word of `result` from `inArray`. 609 // For example, when `result` has 2 words, the first word will be filled with 610 // data. So, the first 8 bytes are copied from `inArray` here. 611 // `inArray` (10 bytes, BE): |abcdefgh|ij| 612 // ==> `result` (2 words, BE): |abcdefgh|--------| 613 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; 614 std::copy_n( 615 inArray, numFilledWords, 616 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 617 618 // Convert array data which will be last word of `result` to LE format, and 619 // store it in char array(`inArrayLE`). 620 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| 621 size_t lastWordPos = numFilledWords; 622 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); 623 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 624 inArray + lastWordPos, inArrayLE.begin(), 625 (numBytes - lastWordPos) * CHAR_BIT, 1); 626 627 // Convert `inArrayLE` to BE format, and store it in last word of `result`. 628 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| 629 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 630 inArrayLE.begin(), 631 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + 632 lastWordPos, 633 APInt::APINT_BITS_PER_WORD, 1); 634 } 635 636 /// Writes value to the bit position `bitPos` in array `rawData`. 637 static void writeBits(char *rawData, size_t bitPos, APInt value) { 638 size_t bitWidth = value.getBitWidth(); 639 640 // If the bitwidth is 1 we just toggle the specific bit. 641 if (bitWidth == 1) 642 return setBit(rawData, bitPos, value.isOneValue()); 643 644 // Otherwise, the bit position is guaranteed to be byte aligned. 645 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 646 if (llvm::support::endian::system_endianness() == 647 llvm::support::endianness::big) { 648 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. 649 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 650 // work correctly in BE format. 651 // ex. `value` (2 words including 10 bytes) 652 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| 653 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), 654 rawData + (bitPos / CHAR_BIT)); 655 } else { 656 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 657 llvm::divideCeil(bitWidth, CHAR_BIT), 658 rawData + (bitPos / CHAR_BIT)); 659 } 660 } 661 662 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 663 /// `rawData`. 664 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 665 // Handle a boolean bit position. 666 if (bitWidth == 1) 667 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 668 669 // Otherwise, the bit position must be 8-bit aligned. 670 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 671 APInt result(bitWidth, 0); 672 if (llvm::support::endian::system_endianness() == 673 llvm::support::endianness::big) { 674 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. 675 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't 676 // work correctly in BE format. 677 // ex. `result` (2 words including 10 bytes) 678 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function 679 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), 680 llvm::divideCeil(bitWidth, CHAR_BIT), result); 681 } else { 682 std::copy_n(rawData + (bitPos / CHAR_BIT), 683 llvm::divideCeil(bitWidth, CHAR_BIT), 684 const_cast<char *>( 685 reinterpret_cast<const char *>(result.getRawData()))); 686 } 687 return result; 688 } 689 690 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 691 /// the same element count as 'type'. 692 template <typename Values> 693 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 694 return (values.size() == 1) || 695 (type.getNumElements() == static_cast<int64_t>(values.size())); 696 } 697 698 //===----------------------------------------------------------------------===// 699 // DenseElementsAttr Iterators 700 //===----------------------------------------------------------------------===// 701 702 //===----------------------------------------------------------------------===// 703 // AttributeElementIterator 704 705 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 706 DenseElementsAttr attr, size_t index) 707 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 708 Attribute, Attribute, Attribute>( 709 attr.getAsOpaquePointer(), index) {} 710 711 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 712 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 713 Type eltTy = owner.getType().getElementType(); 714 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 715 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 716 if (eltTy.isa<IndexType>()) 717 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 718 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 719 IntElementIterator intIt(owner, index); 720 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 721 return FloatAttr::get(eltTy, *floatIt); 722 } 723 if (owner.isa<DenseStringElementsAttr>()) { 724 ArrayRef<StringRef> vals = owner.getRawStringData(); 725 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 726 } 727 llvm_unreachable("unexpected element type"); 728 } 729 730 //===----------------------------------------------------------------------===// 731 // BoolElementIterator 732 733 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 734 DenseElementsAttr attr, size_t dataIndex) 735 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 736 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 737 738 bool DenseElementsAttr::BoolElementIterator::operator*() const { 739 return getBit(getData(), getDataIndex()); 740 } 741 742 //===----------------------------------------------------------------------===// 743 // IntElementIterator 744 745 DenseElementsAttr::IntElementIterator::IntElementIterator( 746 DenseElementsAttr attr, size_t dataIndex) 747 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 748 attr.getRawData().data(), attr.isSplat(), dataIndex), 749 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 750 751 APInt DenseElementsAttr::IntElementIterator::operator*() const { 752 return readBits(getData(), 753 getDataIndex() * getDenseElementStorageWidth(bitWidth), 754 bitWidth); 755 } 756 757 //===----------------------------------------------------------------------===// 758 // ComplexIntElementIterator 759 760 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 761 DenseElementsAttr attr, size_t dataIndex) 762 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 763 std::complex<APInt>, std::complex<APInt>, 764 std::complex<APInt>>( 765 attr.getRawData().data(), attr.isSplat(), dataIndex) { 766 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 767 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 768 } 769 770 std::complex<APInt> 771 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 772 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 773 size_t offset = getDataIndex() * storageWidth * 2; 774 return {readBits(getData(), offset, bitWidth), 775 readBits(getData(), offset + storageWidth, bitWidth)}; 776 } 777 778 //===----------------------------------------------------------------------===// 779 // FloatElementIterator 780 781 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 782 const llvm::fltSemantics &smt, IntElementIterator it) 783 : llvm::mapped_iterator<IntElementIterator, 784 std::function<APFloat(const APInt &)>>( 785 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 786 787 //===----------------------------------------------------------------------===// 788 // ComplexFloatElementIterator 789 790 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 791 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 792 : llvm::mapped_iterator< 793 ComplexIntElementIterator, 794 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 795 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 796 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 797 }) {} 798 799 //===----------------------------------------------------------------------===// 800 // DenseElementsAttr 801 //===----------------------------------------------------------------------===// 802 803 /// Method for support type inquiry through isa, cast and dyn_cast. 804 bool DenseElementsAttr::classof(Attribute attr) { 805 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 806 } 807 808 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 809 ArrayRef<Attribute> values) { 810 assert(hasSameElementsOrSplat(type, values)); 811 812 // If the element type is not based on int/float/index, assume it is a string 813 // type. 814 auto eltType = type.getElementType(); 815 if (!type.getElementType().isIntOrIndexOrFloat()) { 816 SmallVector<StringRef, 8> stringValues; 817 stringValues.reserve(values.size()); 818 for (Attribute attr : values) { 819 assert(attr.isa<StringAttr>() && 820 "expected string value for non integer/index/float element"); 821 stringValues.push_back(attr.cast<StringAttr>().getValue()); 822 } 823 return get(type, stringValues); 824 } 825 826 // Otherwise, get the raw storage width to use for the allocation. 827 size_t bitWidth = getDenseElementBitWidth(eltType); 828 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 829 830 // Compress the attribute values into a character buffer. 831 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 832 values.size()); 833 APInt intVal; 834 for (unsigned i = 0, e = values.size(); i < e; ++i) { 835 assert(eltType == values[i].getType() && 836 "expected attribute value to have element type"); 837 if (eltType.isa<FloatType>()) 838 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 839 else if (eltType.isa<IntegerType>()) 840 intVal = values[i].cast<IntegerAttr>().getValue(); 841 else 842 llvm_unreachable("unexpected element type"); 843 844 assert(intVal.getBitWidth() == bitWidth && 845 "expected value to have same bitwidth as element type"); 846 writeBits(data.data(), i * storageBitWidth, intVal); 847 } 848 return DenseIntOrFPElementsAttr::getRaw(type, data, 849 /*isSplat=*/(values.size() == 1)); 850 } 851 852 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 853 ArrayRef<bool> values) { 854 assert(hasSameElementsOrSplat(type, values)); 855 assert(type.getElementType().isInteger(1)); 856 857 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 858 for (int i = 0, e = values.size(); i != e; ++i) 859 setBit(buff.data(), i, values[i]); 860 return DenseIntOrFPElementsAttr::getRaw(type, buff, 861 /*isSplat=*/(values.size() == 1)); 862 } 863 864 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 865 ArrayRef<StringRef> values) { 866 assert(!type.getElementType().isIntOrFloat()); 867 return DenseStringElementsAttr::get(type, values); 868 } 869 870 /// Constructs a dense integer elements attribute from an array of APInt 871 /// values. Each APInt value is expected to have the same bitwidth as the 872 /// element type of 'type'. 873 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 874 ArrayRef<APInt> values) { 875 assert(type.getElementType().isIntOrIndex()); 876 assert(hasSameElementsOrSplat(type, values)); 877 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 878 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 879 /*isSplat=*/(values.size() == 1)); 880 } 881 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 882 ArrayRef<std::complex<APInt>> values) { 883 ComplexType complex = type.getElementType().cast<ComplexType>(); 884 assert(complex.getElementType().isa<IntegerType>()); 885 assert(hasSameElementsOrSplat(type, values)); 886 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 887 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 888 values.size() * 2); 889 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 890 /*isSplat=*/(values.size() == 1)); 891 } 892 893 // Constructs a dense float elements attribute from an array of APFloat 894 // values. Each APFloat value is expected to have the same bitwidth as the 895 // element type of 'type'. 896 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 897 ArrayRef<APFloat> values) { 898 assert(type.getElementType().isa<FloatType>()); 899 assert(hasSameElementsOrSplat(type, values)); 900 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 901 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 902 /*isSplat=*/(values.size() == 1)); 903 } 904 DenseElementsAttr 905 DenseElementsAttr::get(ShapedType type, 906 ArrayRef<std::complex<APFloat>> values) { 907 ComplexType complex = type.getElementType().cast<ComplexType>(); 908 assert(complex.getElementType().isa<FloatType>()); 909 assert(hasSameElementsOrSplat(type, values)); 910 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 911 values.size() * 2); 912 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 913 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 914 /*isSplat=*/(values.size() == 1)); 915 } 916 917 /// Construct a dense elements attribute from a raw buffer representing the 918 /// data for this attribute. Users should generally not use this methods as 919 /// the expected buffer format may not be a form the user expects. 920 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 921 ArrayRef<char> rawBuffer, 922 bool isSplatBuffer) { 923 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 924 } 925 926 /// Returns true if the given buffer is a valid raw buffer for the given type. 927 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 928 ArrayRef<char> rawBuffer, 929 bool &detectedSplat) { 930 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 931 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 932 933 // Storage width of 1 is special as it is packed by the bit. 934 if (storageWidth == 1) { 935 // Check for a splat, or a buffer equal to the number of elements. 936 if ((detectedSplat = rawBuffer.size() == 1)) 937 return true; 938 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 939 } 940 // All other types are 8-bit aligned. 941 if ((detectedSplat = rawBufferWidth == storageWidth)) 942 return true; 943 return rawBufferWidth == (storageWidth * type.getNumElements()); 944 } 945 946 /// Check the information for a C++ data type, check if this type is valid for 947 /// the current attribute. This method is used to verify specific type 948 /// invariants that the templatized 'getValues' method cannot. 949 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 950 bool isSigned) { 951 // Make sure that the data element size is the same as the type element width. 952 if (getDenseElementBitWidth(type) != 953 static_cast<size_t>(dataEltSize * CHAR_BIT)) 954 return false; 955 956 // Check that the element type is either float or integer or index. 957 if (!isInt) 958 return type.isa<FloatType>(); 959 if (type.isIndex()) 960 return true; 961 962 auto intType = type.dyn_cast<IntegerType>(); 963 if (!intType) 964 return false; 965 966 // Make sure signedness semantics is consistent. 967 if (intType.isSignless()) 968 return true; 969 return intType.isSigned() ? isSigned : !isSigned; 970 } 971 972 /// Defaults down the subclass implementation. 973 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 974 ArrayRef<char> data, 975 int64_t dataEltSize, 976 bool isInt, bool isSigned) { 977 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 978 isSigned); 979 } 980 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 981 ArrayRef<char> data, 982 int64_t dataEltSize, 983 bool isInt, 984 bool isSigned) { 985 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 986 isInt, isSigned); 987 } 988 989 /// A method used to verify specific type invariants that the templatized 'get' 990 /// method cannot. 991 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 992 bool isSigned) const { 993 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 994 isSigned); 995 } 996 997 /// Check the information for a C++ data type, check if this type is valid for 998 /// the current attribute. 999 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 1000 bool isSigned) const { 1001 return ::isValidIntOrFloat( 1002 getType().getElementType().cast<ComplexType>().getElementType(), 1003 dataEltSize / 2, isInt, isSigned); 1004 } 1005 1006 /// Returns true if this attribute corresponds to a splat, i.e. if all element 1007 /// values are the same. 1008 bool DenseElementsAttr::isSplat() const { 1009 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 1010 } 1011 1012 /// Return the held element values as a range of Attributes. 1013 auto DenseElementsAttr::getAttributeValues() const 1014 -> llvm::iterator_range<AttributeElementIterator> { 1015 return {attr_value_begin(), attr_value_end()}; 1016 } 1017 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 1018 return AttributeElementIterator(*this, 0); 1019 } 1020 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 1021 return AttributeElementIterator(*this, getNumElements()); 1022 } 1023 1024 /// Return the held element values as a range of bool. The element type of 1025 /// this attribute must be of integer type of bitwidth 1. 1026 auto DenseElementsAttr::getBoolValues() const 1027 -> llvm::iterator_range<BoolElementIterator> { 1028 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 1029 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 1030 (void)eltType; 1031 return {BoolElementIterator(*this, 0), 1032 BoolElementIterator(*this, getNumElements())}; 1033 } 1034 1035 /// Return the held element values as a range of APInts. The element type of 1036 /// this attribute must be of integer type. 1037 auto DenseElementsAttr::getIntValues() const 1038 -> llvm::iterator_range<IntElementIterator> { 1039 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1040 return {raw_int_begin(), raw_int_end()}; 1041 } 1042 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 1043 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1044 return raw_int_begin(); 1045 } 1046 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 1047 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 1048 return raw_int_end(); 1049 } 1050 auto DenseElementsAttr::getComplexIntValues() const 1051 -> llvm::iterator_range<ComplexIntElementIterator> { 1052 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1053 (void)eltTy; 1054 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 1055 return {ComplexIntElementIterator(*this, 0), 1056 ComplexIntElementIterator(*this, getNumElements())}; 1057 } 1058 1059 /// Return the held element values as a range of APFloat. The element type of 1060 /// this attribute must be of float type. 1061 auto DenseElementsAttr::getFloatValues() const 1062 -> llvm::iterator_range<FloatElementIterator> { 1063 auto elementType = getType().getElementType().cast<FloatType>(); 1064 const auto &elementSemantics = elementType.getFloatSemantics(); 1065 return {FloatElementIterator(elementSemantics, raw_int_begin()), 1066 FloatElementIterator(elementSemantics, raw_int_end())}; 1067 } 1068 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 1069 return getFloatValues().begin(); 1070 } 1071 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 1072 return getFloatValues().end(); 1073 } 1074 auto DenseElementsAttr::getComplexFloatValues() const 1075 -> llvm::iterator_range<ComplexFloatElementIterator> { 1076 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1077 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1078 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1079 return {{semantics, {*this, 0}}, 1080 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1081 } 1082 1083 /// Return the raw storage data held by this attribute. 1084 ArrayRef<char> DenseElementsAttr::getRawData() const { 1085 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1086 } 1087 1088 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1089 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1090 } 1091 1092 /// Return a new DenseElementsAttr that has the same data as the current 1093 /// attribute, but has been reshaped to 'newType'. The new type must have the 1094 /// same total number of elements as well as element type. 1095 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1096 ShapedType curType = getType(); 1097 if (curType == newType) 1098 return *this; 1099 1100 (void)curType; 1101 assert(newType.getElementType() == curType.getElementType() && 1102 "expected the same element type"); 1103 assert(newType.getNumElements() == curType.getNumElements() && 1104 "expected the same number of elements"); 1105 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1106 } 1107 1108 DenseElementsAttr 1109 DenseElementsAttr::mapValues(Type newElementType, 1110 function_ref<APInt(const APInt &)> mapping) const { 1111 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1112 } 1113 1114 DenseElementsAttr DenseElementsAttr::mapValues( 1115 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1116 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1117 } 1118 1119 //===----------------------------------------------------------------------===// 1120 // DenseStringElementsAttr 1121 //===----------------------------------------------------------------------===// 1122 1123 DenseStringElementsAttr 1124 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1125 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1126 } 1127 1128 //===----------------------------------------------------------------------===// 1129 // DenseIntOrFPElementsAttr 1130 //===----------------------------------------------------------------------===// 1131 1132 /// Utility method to write a range of APInt values to a buffer. 1133 template <typename APRangeT> 1134 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1135 APRangeT &&values) { 1136 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1137 size_t offset = 0; 1138 for (auto it = values.begin(), e = values.end(); it != e; 1139 ++it, offset += storageWidth) { 1140 assert((*it).getBitWidth() <= storageWidth); 1141 writeBits(data.data(), offset, *it); 1142 } 1143 } 1144 1145 /// Constructs a dense elements attribute from an array of raw APFloat values. 1146 /// Each APFloat value is expected to have the same bitwidth as the element 1147 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1148 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1149 size_t storageWidth, 1150 ArrayRef<APFloat> values, 1151 bool isSplat) { 1152 std::vector<char> data; 1153 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1154 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1155 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1156 } 1157 1158 /// Constructs a dense elements attribute from an array of raw APInt values. 1159 /// Each APInt value is expected to have the same bitwidth as the element type 1160 /// of 'type'. 1161 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1162 size_t storageWidth, 1163 ArrayRef<APInt> values, 1164 bool isSplat) { 1165 std::vector<char> data; 1166 writeAPIntsToBuffer(storageWidth, data, values); 1167 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1168 } 1169 1170 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1171 ArrayRef<char> data, 1172 bool isSplat) { 1173 assert((type.isa<RankedTensorType, VectorType>()) && 1174 "type must be ranked tensor or vector"); 1175 assert(type.hasStaticShape() && "type must have static shape"); 1176 return Base::get(type.getContext(), type, data, isSplat); 1177 } 1178 1179 /// Overload of the raw 'get' method that asserts that the given type is of 1180 /// complex type. This method is used to verify type invariants that the 1181 /// templatized 'get' method cannot. 1182 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1183 ArrayRef<char> data, 1184 int64_t dataEltSize, 1185 bool isInt, 1186 bool isSigned) { 1187 assert(::isValidIntOrFloat( 1188 type.getElementType().cast<ComplexType>().getElementType(), 1189 dataEltSize / 2, isInt, isSigned)); 1190 1191 int64_t numElements = data.size() / dataEltSize; 1192 assert(numElements == 1 || numElements == type.getNumElements()); 1193 return getRaw(type, data, /*isSplat=*/numElements == 1); 1194 } 1195 1196 /// Overload of the 'getRaw' method that asserts that the given type is of 1197 /// integer type. This method is used to verify type invariants that the 1198 /// templatized 'get' method cannot. 1199 DenseElementsAttr 1200 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1201 int64_t dataEltSize, bool isInt, 1202 bool isSigned) { 1203 assert( 1204 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1205 1206 int64_t numElements = data.size() / dataEltSize; 1207 assert(numElements == 1 || numElements == type.getNumElements()); 1208 return getRaw(type, data, /*isSplat=*/numElements == 1); 1209 } 1210 1211 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1212 const char *inRawData, char *outRawData, size_t elementBitWidth, 1213 size_t numElements) { 1214 using llvm::support::ulittle16_t; 1215 using llvm::support::ulittle32_t; 1216 using llvm::support::ulittle64_t; 1217 1218 assert(llvm::support::endian::system_endianness() == // NOLINT 1219 llvm::support::endianness::big); // NOLINT 1220 // NOLINT to avoid warning message about replacing by static_assert() 1221 1222 // Following std::copy_n always converts endianness on BE machine. 1223 switch (elementBitWidth) { 1224 case 16: { 1225 const ulittle16_t *inRawDataPos = 1226 reinterpret_cast<const ulittle16_t *>(inRawData); 1227 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1228 std::copy_n(inRawDataPos, numElements, outDataPos); 1229 break; 1230 } 1231 case 32: { 1232 const ulittle32_t *inRawDataPos = 1233 reinterpret_cast<const ulittle32_t *>(inRawData); 1234 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1235 std::copy_n(inRawDataPos, numElements, outDataPos); 1236 break; 1237 } 1238 case 64: { 1239 const ulittle64_t *inRawDataPos = 1240 reinterpret_cast<const ulittle64_t *>(inRawData); 1241 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1242 std::copy_n(inRawDataPos, numElements, outDataPos); 1243 break; 1244 } 1245 default: { 1246 size_t nBytes = elementBitWidth / CHAR_BIT; 1247 for (size_t i = 0; i < nBytes; i++) 1248 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1249 break; 1250 } 1251 } 1252 } 1253 1254 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1255 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1256 ShapedType type) { 1257 size_t numElements = type.getNumElements(); 1258 Type elementType = type.getElementType(); 1259 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1260 elementType = complexTy.getElementType(); 1261 numElements = numElements * 2; 1262 } 1263 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1264 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1265 inRawData.size() <= outRawData.size()); 1266 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1267 elementBitWidth, numElements); 1268 } 1269 1270 //===----------------------------------------------------------------------===// 1271 // DenseFPElementsAttr 1272 //===----------------------------------------------------------------------===// 1273 1274 template <typename Fn, typename Attr> 1275 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1276 Type newElementType, 1277 llvm::SmallVectorImpl<char> &data) { 1278 size_t bitWidth = getDenseElementBitWidth(newElementType); 1279 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1280 1281 ShapedType newArrayType; 1282 if (inType.isa<RankedTensorType>()) 1283 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1284 else if (inType.isa<UnrankedTensorType>()) 1285 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1286 else if (inType.isa<VectorType>()) 1287 newArrayType = VectorType::get(inType.getShape(), newElementType); 1288 else 1289 assert(newArrayType && "Unhandled tensor type"); 1290 1291 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1292 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1293 1294 // Functor used to process a single element value of the attribute. 1295 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1296 auto newInt = mapping(value); 1297 assert(newInt.getBitWidth() == bitWidth); 1298 writeBits(data.data(), index * storageBitWidth, newInt); 1299 }; 1300 1301 // Check for the splat case. 1302 if (attr.isSplat()) { 1303 processElt(*attr.begin(), /*index=*/0); 1304 return newArrayType; 1305 } 1306 1307 // Otherwise, process all of the element values. 1308 uint64_t elementIdx = 0; 1309 for (auto value : attr) 1310 processElt(value, elementIdx++); 1311 return newArrayType; 1312 } 1313 1314 DenseElementsAttr DenseFPElementsAttr::mapValues( 1315 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1316 llvm::SmallVector<char, 8> elementData; 1317 auto newArrayType = 1318 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1319 1320 return getRaw(newArrayType, elementData, isSplat()); 1321 } 1322 1323 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1324 bool DenseFPElementsAttr::classof(Attribute attr) { 1325 return attr.isa<DenseElementsAttr>() && 1326 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1327 } 1328 1329 //===----------------------------------------------------------------------===// 1330 // DenseIntElementsAttr 1331 //===----------------------------------------------------------------------===// 1332 1333 DenseElementsAttr DenseIntElementsAttr::mapValues( 1334 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1335 llvm::SmallVector<char, 8> elementData; 1336 auto newArrayType = 1337 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1338 1339 return getRaw(newArrayType, elementData, isSplat()); 1340 } 1341 1342 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1343 bool DenseIntElementsAttr::classof(Attribute attr) { 1344 return attr.isa<DenseElementsAttr>() && 1345 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1346 } 1347 1348 //===----------------------------------------------------------------------===// 1349 // OpaqueElementsAttr 1350 //===----------------------------------------------------------------------===// 1351 1352 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1353 StringRef bytes) { 1354 assert(TensorType::isValidElementType(type.getElementType()) && 1355 "Input element type should be a valid tensor element type"); 1356 return Base::get(type.getContext(), type, dialect, bytes); 1357 } 1358 1359 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1360 1361 /// Return the value at the given index. If index does not refer to a valid 1362 /// element, then a null attribute is returned. 1363 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1364 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1365 return Attribute(); 1366 } 1367 1368 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1369 1370 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1371 auto *d = getDialect(); 1372 if (!d) 1373 return true; 1374 auto *interface = 1375 d->getRegisteredInterface<DialectDecodeAttributesInterface>(); 1376 if (!interface) 1377 return true; 1378 return failed(interface->decode(*this, result)); 1379 } 1380 1381 //===----------------------------------------------------------------------===// 1382 // SparseElementsAttr 1383 //===----------------------------------------------------------------------===// 1384 1385 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1386 DenseElementsAttr indices, 1387 DenseElementsAttr values) { 1388 assert(indices.getType().getElementType().isInteger(64) && 1389 "expected sparse indices to be 64-bit integer values"); 1390 assert((type.isa<RankedTensorType, VectorType>()) && 1391 "type must be ranked tensor or vector"); 1392 assert(type.hasStaticShape() && "type must have static shape"); 1393 return Base::get(type.getContext(), type, 1394 indices.cast<DenseIntElementsAttr>(), values); 1395 } 1396 1397 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1398 return getImpl()->indices; 1399 } 1400 1401 DenseElementsAttr SparseElementsAttr::getValues() const { 1402 return getImpl()->values; 1403 } 1404 1405 /// Return the value of the element at the given index. 1406 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1407 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1408 auto type = getType(); 1409 1410 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1411 // as a 1-D index array. 1412 auto sparseIndices = getIndices(); 1413 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1414 1415 // Check to see if the indices are a splat. 1416 if (sparseIndices.isSplat()) { 1417 // If the index is also not a splat of the index value, we know that the 1418 // value is zero. 1419 auto splatIndex = *sparseIndexValues.begin(); 1420 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1421 return getZeroAttr(); 1422 1423 // If the indices are a splat, we also expect the values to be a splat. 1424 assert(getValues().isSplat() && "expected splat values"); 1425 return getValues().getSplatValue(); 1426 } 1427 1428 // Build a mapping between known indices and the offset of the stored element. 1429 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1430 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1431 size_t rank = type.getRank(); 1432 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1433 mappedIndices.try_emplace( 1434 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1435 1436 // Look for the provided index key within the mapped indices. If the provided 1437 // index is not found, then return a zero attribute. 1438 auto it = mappedIndices.find(index); 1439 if (it == mappedIndices.end()) 1440 return getZeroAttr(); 1441 1442 // Otherwise, return the held sparse value element. 1443 return getValues().getValue(it->second); 1444 } 1445 1446 /// Get a zero APFloat for the given sparse attribute. 1447 APFloat SparseElementsAttr::getZeroAPFloat() const { 1448 auto eltType = getType().getElementType().cast<FloatType>(); 1449 return APFloat(eltType.getFloatSemantics()); 1450 } 1451 1452 /// Get a zero APInt for the given sparse attribute. 1453 APInt SparseElementsAttr::getZeroAPInt() const { 1454 auto eltType = getType().getElementType().cast<IntegerType>(); 1455 return APInt::getNullValue(eltType.getWidth()); 1456 } 1457 1458 /// Get a zero attribute for the given attribute type. 1459 Attribute SparseElementsAttr::getZeroAttr() const { 1460 auto eltType = getType().getElementType(); 1461 1462 // Handle floating point elements. 1463 if (eltType.isa<FloatType>()) 1464 return FloatAttr::get(eltType, 0); 1465 1466 // Otherwise, this is an integer. 1467 // TODO: Handle StringAttr here. 1468 return IntegerAttr::get(eltType, 0); 1469 } 1470 1471 /// Flatten, and return, all of the sparse indices in this attribute in 1472 /// row-major order. 1473 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1474 std::vector<ptrdiff_t> flatSparseIndices; 1475 1476 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1477 // as a 1-D index array. 1478 auto sparseIndices = getIndices(); 1479 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1480 if (sparseIndices.isSplat()) { 1481 SmallVector<uint64_t, 8> indices(getType().getRank(), 1482 *sparseIndexValues.begin()); 1483 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1484 return flatSparseIndices; 1485 } 1486 1487 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1488 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1489 size_t rank = getType().getRank(); 1490 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1491 flatSparseIndices.push_back(getFlattenedIndex( 1492 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1493 return flatSparseIndices; 1494 } 1495 1496 //===----------------------------------------------------------------------===// 1497 // MutableDictionaryAttr 1498 //===----------------------------------------------------------------------===// 1499 1500 MutableDictionaryAttr::MutableDictionaryAttr( 1501 ArrayRef<NamedAttribute> attributes) { 1502 setAttrs(attributes); 1503 } 1504 1505 /// Return the underlying dictionary attribute. 1506 DictionaryAttr 1507 MutableDictionaryAttr::getDictionary(MLIRContext *context) const { 1508 // Construct empty DictionaryAttr if needed. 1509 if (!attrs) 1510 return DictionaryAttr::get({}, context); 1511 return attrs; 1512 } 1513 1514 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { 1515 return attrs ? attrs.getValue() : llvm::None; 1516 } 1517 1518 /// Replace the held attributes with ones provided in 'newAttrs'. 1519 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { 1520 // Don't create an attribute list if there are no attributes. 1521 if (attributes.empty()) 1522 attrs = nullptr; 1523 else 1524 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1525 } 1526 1527 /// Return the specified attribute if present, null otherwise. 1528 Attribute MutableDictionaryAttr::get(StringRef name) const { 1529 return attrs ? attrs.get(name) : nullptr; 1530 } 1531 1532 /// Return the specified attribute if present, null otherwise. 1533 Attribute MutableDictionaryAttr::get(Identifier name) const { 1534 return attrs ? attrs.get(name) : nullptr; 1535 } 1536 1537 /// Return the specified named attribute if present, None otherwise. 1538 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1539 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1540 } 1541 Optional<NamedAttribute> 1542 MutableDictionaryAttr::getNamed(Identifier name) const { 1543 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1544 } 1545 1546 /// If the an attribute exists with the specified name, change it to the new 1547 /// value. Otherwise, add a new attribute with the specified name/value. 1548 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1549 assert(value && "attributes may never be null"); 1550 1551 // Look for an existing value for the given name, and set it in-place. 1552 ArrayRef<NamedAttribute> values = getAttrs(); 1553 const auto *it = llvm::find_if( 1554 values, [name](NamedAttribute attr) { return attr.first == name; }); 1555 if (it != values.end()) { 1556 // Bail out early if the value is the same as what we already have. 1557 if (it->second == value) 1558 return; 1559 1560 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1561 newAttrs[it - values.begin()].second = value; 1562 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1563 return; 1564 } 1565 1566 // Otherwise, insert the new attribute into its sorted position. 1567 it = llvm::lower_bound(values, name); 1568 SmallVector<NamedAttribute, 8> newAttrs; 1569 newAttrs.reserve(values.size() + 1); 1570 newAttrs.append(values.begin(), it); 1571 newAttrs.push_back({name, value}); 1572 newAttrs.append(it, values.end()); 1573 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1574 } 1575 1576 /// Remove the attribute with the specified name if it exists. The return 1577 /// value indicates whether the attribute was present or not. 1578 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1579 auto origAttrs = getAttrs(); 1580 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1581 if (origAttrs[i].first == name) { 1582 // Handle the simple case of removing the only attribute in the list. 1583 if (e == 1) { 1584 attrs = nullptr; 1585 return RemoveResult::Removed; 1586 } 1587 1588 SmallVector<NamedAttribute, 8> newAttrs; 1589 newAttrs.reserve(origAttrs.size() - 1); 1590 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1591 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1592 attrs = DictionaryAttr::getWithSorted(newAttrs, 1593 newAttrs[0].second.getContext()); 1594 return RemoveResult::Removed; 1595 } 1596 } 1597 return RemoveResult::NotFound; 1598 } 1599 1600 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { 1601 return strcmp(lhs.first.data(), rhs.first.data()) < 0; 1602 } 1603 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { 1604 // This is correct even when attr.first.data()[name.size()] is not a zero 1605 // string terminator, because we only care about a less than comparison. 1606 // This can't use memcmp, because it doesn't guarantee that it will stop 1607 // reading both buffers if one is shorter than the other, even if there is 1608 // a difference. 1609 return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; 1610 } 1611