1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AttributeStorage 26 //===----------------------------------------------------------------------===// 27 28 AttributeStorage::AttributeStorage(Type type) 29 : type(type.getAsOpaquePointer()) {} 30 AttributeStorage::AttributeStorage() : type(nullptr) {} 31 32 Type AttributeStorage::getType() const { 33 return Type::getFromOpaquePointer(type); 34 } 35 void AttributeStorage::setType(Type newType) { 36 type = newType.getAsOpaquePointer(); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Attribute 41 //===----------------------------------------------------------------------===// 42 43 /// Return the type of this attribute. 44 Type Attribute::getType() const { return impl->getType(); } 45 46 /// Return the context this attribute belongs to. 47 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 48 49 /// Get the dialect this attribute is registered to. 50 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 51 52 //===----------------------------------------------------------------------===// 53 // AffineMapAttr 54 //===----------------------------------------------------------------------===// 55 56 AffineMapAttr AffineMapAttr::get(AffineMap value) { 57 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 58 } 59 60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 61 62 //===----------------------------------------------------------------------===// 63 // ArrayAttr 64 //===----------------------------------------------------------------------===// 65 66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 67 return Base::get(context, StandardAttributes::Array, value); 68 } 69 70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 71 72 Attribute ArrayAttr::operator[](unsigned idx) const { 73 assert(idx < size() && "index out of bounds"); 74 return getValue()[idx]; 75 } 76 77 //===----------------------------------------------------------------------===// 78 // BoolAttr 79 //===----------------------------------------------------------------------===// 80 81 bool BoolAttr::getValue() const { return getImpl()->value; } 82 83 //===----------------------------------------------------------------------===// 84 // DictionaryAttr 85 //===----------------------------------------------------------------------===// 86 87 /// Perform a three-way comparison between the names of the specified 88 /// NamedAttributes. 89 static int compareNamedAttributes(const NamedAttribute *lhs, 90 const NamedAttribute *rhs) { 91 return strcmp(lhs->first.data(), rhs->first.data()); 92 } 93 94 /// Returns if the name of the given attribute precedes that of 'name'. 95 static bool compareNamedAttributeWithName(const NamedAttribute &attr, 96 StringRef name) { 97 // This is correct even when attr.first.data()[name.size()] is not a zero 98 // string terminator, because we only care about a less than comparison. 99 // This can't use memcmp, because it doesn't guarantee that it will stop 100 // reading both buffers if one is shorter than the other, even if there is 101 // a difference. 102 return strncmp(attr.first.data(), name.data(), name.size()) < 0; 103 } 104 105 /// Helper function that does either an in place sort or sorts from source array 106 /// into destination. If inPlace then storage is both the source and the 107 /// destination, else value is the source and storage destination. Returns 108 /// whether source was sorted. 109 template <bool inPlace> 110 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 111 SmallVectorImpl<NamedAttribute> &storage) { 112 // Specialize for the common case. 113 switch (value.size()) { 114 case 0: 115 case 1: 116 // Zero or one elements are already sorted. 117 break; 118 case 2: 119 assert(value[0].first != value[1].first && 120 "DictionaryAttr element names must be unique"); 121 if (compareNamedAttributes(&value[0], &value[1]) > 0) { 122 if (inPlace) 123 std::swap(storage[0], storage[1]); 124 else 125 storage.append({value[1], value[0]}); 126 return true; 127 } 128 break; 129 default: 130 // Check to see they are sorted already. 131 bool isSorted = 132 llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) { 133 return compareNamedAttributes(&l, &r) < 0; 134 }); 135 if (!isSorted) { 136 // If not, do a general sort. 137 if (!inPlace) 138 storage.append(value.begin(), value.end()); 139 llvm::array_pod_sort(storage.begin(), storage.end(), 140 compareNamedAttributes); 141 value = storage; 142 } 143 144 // Ensure that the attribute elements are unique. 145 assert(std::adjacent_find(value.begin(), value.end(), 146 [](NamedAttribute l, NamedAttribute r) { 147 return l.first == r.first; 148 }) == value.end() && 149 "DictionaryAttr element names must be unique"); 150 return !isSorted; 151 } 152 return false; 153 } 154 155 /// Sorts the NamedAttributes in the array ordered by name as expected by 156 /// getWithSorted. 157 /// Requires: uniquely named attributes. 158 void DictionaryAttr::sort(SmallVectorImpl<NamedAttribute> &array) { 159 dictionaryAttrSort</*inPlace=*/true>(array, array); 160 } 161 162 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 163 MLIRContext *context) { 164 assert(llvm::all_of(value, 165 [](const NamedAttribute &attr) { return attr.second; }) && 166 "value cannot have null entries"); 167 168 // We need to sort the element list to canonicalize it. 169 SmallVector<NamedAttribute, 8> storage; 170 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 171 value = storage; 172 173 return Base::get(context, StandardAttributes::Dictionary, value); 174 } 175 176 /// Construct a dictionary with an array of values that is known to already be 177 /// sorted by name and uniqued. 178 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 179 MLIRContext *context) { 180 // Ensure that the attribute elements are unique and sorted. 181 assert(llvm::is_sorted(value, 182 [](NamedAttribute l, NamedAttribute r) { 183 return l.first.strref() < r.first.strref(); 184 }) && 185 "expected attribute values to be sorted"); 186 assert(std::adjacent_find(value.begin(), value.end(), 187 [](NamedAttribute l, NamedAttribute r) { 188 return l.first == r.first; 189 }) == value.end() && 190 "DictionaryAttr element names must be unique"); 191 return Base::get(context, StandardAttributes::Dictionary, value); 192 } 193 194 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 195 return getImpl()->getElements(); 196 } 197 198 /// Return the specified attribute if present, null otherwise. 199 Attribute DictionaryAttr::get(StringRef name) const { 200 Optional<NamedAttribute> attr = getNamed(name); 201 return attr ? attr->second : nullptr; 202 } 203 Attribute DictionaryAttr::get(Identifier name) const { 204 Optional<NamedAttribute> attr = getNamed(name); 205 return attr ? attr->second : nullptr; 206 } 207 208 /// Return the specified named attribute if present, None otherwise. 209 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 210 ArrayRef<NamedAttribute> values = getValue(); 211 auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); 212 return it != values.end() && it->first == name ? *it 213 : Optional<NamedAttribute>(); 214 } 215 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 216 for (auto elt : getValue()) 217 if (elt.first == name) 218 return elt; 219 return llvm::None; 220 } 221 222 DictionaryAttr::iterator DictionaryAttr::begin() const { 223 return getValue().begin(); 224 } 225 DictionaryAttr::iterator DictionaryAttr::end() const { 226 return getValue().end(); 227 } 228 size_t DictionaryAttr::size() const { return getValue().size(); } 229 230 //===----------------------------------------------------------------------===// 231 // FloatAttr 232 //===----------------------------------------------------------------------===// 233 234 FloatAttr FloatAttr::get(Type type, double value) { 235 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 236 } 237 238 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 239 return Base::getChecked(loc, StandardAttributes::Float, type, value); 240 } 241 242 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 243 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 244 } 245 246 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 247 return Base::getChecked(loc, StandardAttributes::Float, type, value); 248 } 249 250 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 251 252 double FloatAttr::getValueAsDouble() const { 253 return getValueAsDouble(getValue()); 254 } 255 double FloatAttr::getValueAsDouble(APFloat value) { 256 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 257 bool losesInfo = false; 258 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 259 &losesInfo); 260 } 261 return value.convertToDouble(); 262 } 263 264 /// Verify construction invariants. 265 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 266 if (!type.isa<FloatType>()) 267 return emitError(loc, "expected floating point type"); 268 return success(); 269 } 270 271 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 272 double value) { 273 return verifyFloatTypeInvariants(loc, type); 274 } 275 276 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 277 const APFloat &value) { 278 // Verify that the type is correct. 279 if (failed(verifyFloatTypeInvariants(loc, type))) 280 return failure(); 281 282 // Verify that the type semantics match that of the value. 283 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 284 return emitError( 285 loc, "FloatAttr type doesn't match the type implied by its value"); 286 } 287 return success(); 288 } 289 290 //===----------------------------------------------------------------------===// 291 // SymbolRefAttr 292 //===----------------------------------------------------------------------===// 293 294 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 295 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 296 .cast<FlatSymbolRefAttr>(); 297 } 298 299 SymbolRefAttr SymbolRefAttr::get(StringRef value, 300 ArrayRef<FlatSymbolRefAttr> nestedReferences, 301 MLIRContext *ctx) { 302 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 303 } 304 305 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 306 307 StringRef SymbolRefAttr::getLeafReference() const { 308 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 309 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 310 } 311 312 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 313 return getImpl()->getNestedRefs(); 314 } 315 316 //===----------------------------------------------------------------------===// 317 // IntegerAttr 318 //===----------------------------------------------------------------------===// 319 320 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 321 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 322 } 323 324 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 325 // This uses 64 bit APInts by default for index type. 326 if (type.isIndex()) 327 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 328 329 auto intType = type.cast<IntegerType>(); 330 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 331 } 332 333 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 334 335 int64_t IntegerAttr::getInt() const { 336 assert((getImpl()->getType().isIndex() || 337 getImpl()->getType().isSignlessInteger()) && 338 "must be signless integer"); 339 return getValue().getSExtValue(); 340 } 341 342 int64_t IntegerAttr::getSInt() const { 343 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 344 return getValue().getSExtValue(); 345 } 346 347 uint64_t IntegerAttr::getUInt() const { 348 assert(getImpl()->getType().isUnsignedInteger() && 349 "must be unsigned integer"); 350 return getValue().getZExtValue(); 351 } 352 353 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 354 if (type.isa<IntegerType>() || type.isa<IndexType>()) 355 return success(); 356 return emitError(loc, "expected integer or index type"); 357 } 358 359 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 360 int64_t value) { 361 return verifyIntegerTypeInvariants(loc, type); 362 } 363 364 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 365 const APInt &value) { 366 if (failed(verifyIntegerTypeInvariants(loc, type))) 367 return failure(); 368 if (auto integerType = type.dyn_cast<IntegerType>()) 369 if (integerType.getWidth() != value.getBitWidth()) 370 return emitError(loc, "integer type bit width (") 371 << integerType.getWidth() << ") doesn't match value bit width (" 372 << value.getBitWidth() << ")"; 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // IntegerSetAttr 378 //===----------------------------------------------------------------------===// 379 380 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 381 return Base::get(value.getConstraint(0).getContext(), 382 StandardAttributes::IntegerSet, value); 383 } 384 385 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 386 387 //===----------------------------------------------------------------------===// 388 // OpaqueAttr 389 //===----------------------------------------------------------------------===// 390 391 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 392 MLIRContext *context) { 393 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 394 type); 395 } 396 397 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 398 Type type, Location location) { 399 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 400 attrData, type); 401 } 402 403 /// Returns the dialect namespace of the opaque attribute. 404 Identifier OpaqueAttr::getDialectNamespace() const { 405 return getImpl()->dialectNamespace; 406 } 407 408 /// Returns the raw attribute data of the opaque attribute. 409 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 410 411 /// Verify the construction of an opaque attribute. 412 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 413 Identifier dialect, 414 StringRef attrData, 415 Type type) { 416 if (!Dialect::isValidNamespace(dialect.strref())) 417 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 418 return success(); 419 } 420 421 //===----------------------------------------------------------------------===// 422 // StringAttr 423 //===----------------------------------------------------------------------===// 424 425 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 426 return get(bytes, NoneType::get(context)); 427 } 428 429 /// Get an instance of a StringAttr with the given string and Type. 430 StringAttr StringAttr::get(StringRef bytes, Type type) { 431 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 432 } 433 434 StringRef StringAttr::getValue() const { return getImpl()->value; } 435 436 //===----------------------------------------------------------------------===// 437 // TypeAttr 438 //===----------------------------------------------------------------------===// 439 440 TypeAttr TypeAttr::get(Type value) { 441 return Base::get(value.getContext(), StandardAttributes::Type, value); 442 } 443 444 Type TypeAttr::getValue() const { return getImpl()->value; } 445 446 //===----------------------------------------------------------------------===// 447 // ElementsAttr 448 //===----------------------------------------------------------------------===// 449 450 ShapedType ElementsAttr::getType() const { 451 return Attribute::getType().cast<ShapedType>(); 452 } 453 454 /// Returns the number of elements held by this attribute. 455 int64_t ElementsAttr::getNumElements() const { 456 return getType().getNumElements(); 457 } 458 459 /// Return the value at the given index. If index does not refer to a valid 460 /// element, then a null attribute is returned. 461 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 462 switch (getKind()) { 463 case StandardAttributes::DenseIntOrFPElements: 464 return cast<DenseElementsAttr>().getValue(index); 465 case StandardAttributes::OpaqueElements: 466 return cast<OpaqueElementsAttr>().getValue(index); 467 case StandardAttributes::SparseElements: 468 return cast<SparseElementsAttr>().getValue(index); 469 default: 470 llvm_unreachable("unknown ElementsAttr kind"); 471 } 472 } 473 474 /// Return if the given 'index' refers to a valid element in this attribute. 475 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 476 auto type = getType(); 477 478 // Verify that the rank of the indices matches the held type. 479 auto rank = type.getRank(); 480 if (rank != static_cast<int64_t>(index.size())) 481 return false; 482 483 // Verify that all of the indices are within the shape dimensions. 484 auto shape = type.getShape(); 485 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 486 return static_cast<int64_t>(index[i]) < shape[i]; 487 }); 488 } 489 490 ElementsAttr 491 ElementsAttr::mapValues(Type newElementType, 492 function_ref<APInt(const APInt &)> mapping) const { 493 switch (getKind()) { 494 case StandardAttributes::DenseIntOrFPElements: 495 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 496 default: 497 llvm_unreachable("unsupported ElementsAttr subtype"); 498 } 499 } 500 501 ElementsAttr 502 ElementsAttr::mapValues(Type newElementType, 503 function_ref<APInt(const APFloat &)> mapping) const { 504 switch (getKind()) { 505 case StandardAttributes::DenseIntOrFPElements: 506 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 507 default: 508 llvm_unreachable("unsupported ElementsAttr subtype"); 509 } 510 } 511 512 /// Returns the 1 dimensional flattened row-major index from the given 513 /// multi-dimensional index. 514 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 515 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 516 auto type = getType(); 517 518 // Reduce the provided multidimensional index into a flattended 1D row-major 519 // index. 520 auto rank = type.getRank(); 521 auto shape = type.getShape(); 522 uint64_t valueIndex = 0; 523 uint64_t dimMultiplier = 1; 524 for (int i = rank - 1; i >= 0; --i) { 525 valueIndex += index[i] * dimMultiplier; 526 dimMultiplier *= shape[i]; 527 } 528 return valueIndex; 529 } 530 531 //===----------------------------------------------------------------------===// 532 // DenseElementAttr Utilities 533 //===----------------------------------------------------------------------===// 534 535 /// Get the bitwidth of a dense element type within the buffer. 536 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 537 static size_t getDenseElementStorageWidth(size_t origWidth) { 538 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 539 } 540 static size_t getDenseElementStorageWidth(Type elementType) { 541 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 542 } 543 544 /// Set a bit to a specific value. 545 static void setBit(char *rawData, size_t bitPos, bool value) { 546 if (value) 547 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 548 else 549 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 550 } 551 552 /// Return the value of the specified bit. 553 static bool getBit(const char *rawData, size_t bitPos) { 554 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 555 } 556 557 /// Get start position of actual data in `value`. Actual data is 558 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. 559 static char *getAPIntDataPos(APInt &value, size_t bitWidth) { 560 char *dataPos = 561 const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); 562 if (llvm::support::endian::system_endianness() == 563 llvm::support::endianness::big) 564 dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); 565 return dataPos; 566 } 567 568 /// Read APInt `value` from appropriate position. 569 static void readAPInt(APInt &value, size_t bitWidth, char *outData) { 570 char *dataPos = getAPIntDataPos(value, bitWidth); 571 std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); 572 } 573 574 /// Write `inData` to appropriate position of APInt `value`. 575 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { 576 char *dataPos = getAPIntDataPos(value, bitWidth); 577 std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); 578 } 579 580 /// Writes value to the bit position `bitPos` in array `rawData`. 581 static void writeBits(char *rawData, size_t bitPos, APInt value) { 582 size_t bitWidth = value.getBitWidth(); 583 584 // If the bitwidth is 1 we just toggle the specific bit. 585 if (bitWidth == 1) 586 return setBit(rawData, bitPos, value.isOneValue()); 587 588 // Otherwise, the bit position is guaranteed to be byte aligned. 589 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 590 readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); 591 } 592 593 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 594 /// `rawData`. 595 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 596 // Handle a boolean bit position. 597 if (bitWidth == 1) 598 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 599 600 // Otherwise, the bit position must be 8-bit aligned. 601 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 602 APInt result(bitWidth, 0); 603 writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); 604 return result; 605 } 606 607 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 608 /// same element count as 'type'. 609 template <typename Values> 610 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 611 return (values.size() == 1) || 612 (type.getNumElements() == static_cast<int64_t>(values.size())); 613 } 614 615 //===----------------------------------------------------------------------===// 616 // DenseElementAttr Iterators 617 //===----------------------------------------------------------------------===// 618 619 //===----------------------------------------------------------------------===// 620 // AttributeElementIterator 621 622 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 623 DenseElementsAttr attr, size_t index) 624 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 625 Attribute, Attribute, Attribute>( 626 attr.getAsOpaquePointer(), index) {} 627 628 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 629 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 630 Type eltTy = owner.getType().getElementType(); 631 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 632 if (intEltTy.getWidth() == 1) 633 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 634 owner.getContext()); 635 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 636 } 637 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 638 IntElementIterator intIt(owner, index); 639 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 640 return FloatAttr::get(eltTy, *floatIt); 641 } 642 if (owner.isa<DenseStringElementsAttr>()) { 643 ArrayRef<StringRef> vals = owner.getRawStringData(); 644 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 645 } 646 llvm_unreachable("unexpected element type"); 647 } 648 649 //===----------------------------------------------------------------------===// 650 // BoolElementIterator 651 652 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 653 DenseElementsAttr attr, size_t dataIndex) 654 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 655 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 656 657 bool DenseElementsAttr::BoolElementIterator::operator*() const { 658 return getBit(getData(), getDataIndex()); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // IntElementIterator 663 664 DenseElementsAttr::IntElementIterator::IntElementIterator( 665 DenseElementsAttr attr, size_t dataIndex) 666 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 667 attr.getRawData().data(), attr.isSplat(), dataIndex), 668 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 669 670 APInt DenseElementsAttr::IntElementIterator::operator*() const { 671 return readBits(getData(), 672 getDataIndex() * getDenseElementStorageWidth(bitWidth), 673 bitWidth); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // ComplexIntElementIterator 678 679 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 680 DenseElementsAttr attr, size_t dataIndex) 681 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 682 std::complex<APInt>, std::complex<APInt>, 683 std::complex<APInt>>( 684 attr.getRawData().data(), attr.isSplat(), dataIndex) { 685 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 686 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 687 } 688 689 std::complex<APInt> 690 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 691 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 692 size_t offset = getDataIndex() * storageWidth * 2; 693 return {readBits(getData(), offset, bitWidth), 694 readBits(getData(), offset + storageWidth, bitWidth)}; 695 } 696 697 //===----------------------------------------------------------------------===// 698 // FloatElementIterator 699 700 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 701 const llvm::fltSemantics &smt, IntElementIterator it) 702 : llvm::mapped_iterator<IntElementIterator, 703 std::function<APFloat(const APInt &)>>( 704 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 705 706 //===----------------------------------------------------------------------===// 707 // ComplexFloatElementIterator 708 709 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 710 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 711 : llvm::mapped_iterator< 712 ComplexIntElementIterator, 713 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 714 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 715 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 716 }) {} 717 718 //===----------------------------------------------------------------------===// 719 // DenseElementsAttr 720 //===----------------------------------------------------------------------===// 721 722 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 723 ArrayRef<Attribute> values) { 724 assert(hasSameElementsOrSplat(type, values)); 725 726 // If the element type is not based on int/float/index, assume it is a string 727 // type. 728 auto eltType = type.getElementType(); 729 if (!type.getElementType().isIntOrIndexOrFloat()) { 730 SmallVector<StringRef, 8> stringValues; 731 stringValues.reserve(values.size()); 732 for (Attribute attr : values) { 733 assert(attr.isa<StringAttr>() && 734 "expected string value for non integer/index/float element"); 735 stringValues.push_back(attr.cast<StringAttr>().getValue()); 736 } 737 return get(type, stringValues); 738 } 739 740 // Otherwise, get the raw storage width to use for the allocation. 741 size_t bitWidth = getDenseElementBitWidth(eltType); 742 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 743 744 // Compress the attribute values into a character buffer. 745 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 746 values.size()); 747 APInt intVal; 748 for (unsigned i = 0, e = values.size(); i < e; ++i) { 749 assert(eltType == values[i].getType() && 750 "expected attribute value to have element type"); 751 752 switch (eltType.getKind()) { 753 case StandardTypes::BF16: 754 case StandardTypes::F16: 755 case StandardTypes::F32: 756 case StandardTypes::F64: 757 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 758 break; 759 case StandardTypes::Integer: 760 case StandardTypes::Index: 761 intVal = values[i].isa<BoolAttr>() 762 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 763 : values[i].cast<IntegerAttr>().getValue(); 764 break; 765 default: 766 llvm_unreachable("unexpected element type"); 767 } 768 assert(intVal.getBitWidth() == bitWidth && 769 "expected value to have same bitwidth as element type"); 770 writeBits(data.data(), i * storageBitWidth, intVal); 771 } 772 return DenseIntOrFPElementsAttr::getRaw(type, data, 773 /*isSplat=*/(values.size() == 1)); 774 } 775 776 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 777 ArrayRef<bool> values) { 778 assert(hasSameElementsOrSplat(type, values)); 779 assert(type.getElementType().isInteger(1)); 780 781 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 782 for (int i = 0, e = values.size(); i != e; ++i) 783 setBit(buff.data(), i, values[i]); 784 return DenseIntOrFPElementsAttr::getRaw(type, buff, 785 /*isSplat=*/(values.size() == 1)); 786 } 787 788 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 789 ArrayRef<StringRef> values) { 790 assert(!type.getElementType().isIntOrFloat()); 791 return DenseStringElementsAttr::get(type, values); 792 } 793 794 /// Constructs a dense integer elements attribute from an array of APInt 795 /// values. Each APInt value is expected to have the same bitwidth as the 796 /// element type of 'type'. 797 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 798 ArrayRef<APInt> values) { 799 assert(type.getElementType().isIntOrIndex()); 800 assert(hasSameElementsOrSplat(type, values)); 801 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 802 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 803 /*isSplat=*/(values.size() == 1)); 804 } 805 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 806 ArrayRef<std::complex<APInt>> values) { 807 ComplexType complex = type.getElementType().cast<ComplexType>(); 808 assert(complex.getElementType().isa<IntegerType>()); 809 assert(hasSameElementsOrSplat(type, values)); 810 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 811 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 812 values.size() * 2); 813 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 814 /*isSplat=*/(values.size() == 1)); 815 } 816 817 // Constructs a dense float elements attribute from an array of APFloat 818 // values. Each APFloat value is expected to have the same bitwidth as the 819 // element type of 'type'. 820 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 821 ArrayRef<APFloat> values) { 822 assert(type.getElementType().isa<FloatType>()); 823 assert(hasSameElementsOrSplat(type, values)); 824 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 825 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 826 /*isSplat=*/(values.size() == 1)); 827 } 828 DenseElementsAttr 829 DenseElementsAttr::get(ShapedType type, 830 ArrayRef<std::complex<APFloat>> values) { 831 ComplexType complex = type.getElementType().cast<ComplexType>(); 832 assert(complex.getElementType().isa<FloatType>()); 833 assert(hasSameElementsOrSplat(type, values)); 834 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 835 values.size() * 2); 836 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 837 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 838 /*isSplat=*/(values.size() == 1)); 839 } 840 841 /// Construct a dense elements attribute from a raw buffer representing the 842 /// data for this attribute. Users should generally not use this methods as 843 /// the expected buffer format may not be a form the user expects. 844 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 845 ArrayRef<char> rawBuffer, 846 bool isSplatBuffer) { 847 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 848 } 849 850 /// Returns true if the given buffer is a valid raw buffer for the given type. 851 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 852 ArrayRef<char> rawBuffer, 853 bool &detectedSplat) { 854 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 855 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 856 857 // Storage width of 1 is special as it is packed by the bit. 858 if (storageWidth == 1) { 859 // Check for a splat, or a buffer equal to the number of elements. 860 if ((detectedSplat = rawBuffer.size() == 1)) 861 return true; 862 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 863 } 864 // All other types are 8-bit aligned. 865 if ((detectedSplat = rawBufferWidth == storageWidth)) 866 return true; 867 return rawBufferWidth == (storageWidth * type.getNumElements()); 868 } 869 870 /// Check the information for a C++ data type, check if this type is valid for 871 /// the current attribute. This method is used to verify specific type 872 /// invariants that the templatized 'getValues' method cannot. 873 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 874 bool isSigned) { 875 // Make sure that the data element size is the same as the type element width. 876 if (getDenseElementBitWidth(type) != 877 static_cast<size_t>(dataEltSize * CHAR_BIT)) 878 return false; 879 880 // Check that the element type is either float or integer or index. 881 if (!isInt) 882 return type.isa<FloatType>(); 883 if (type.isIndex()) 884 return true; 885 886 auto intType = type.dyn_cast<IntegerType>(); 887 if (!intType) 888 return false; 889 890 // Make sure signedness semantics is consistent. 891 if (intType.isSignless()) 892 return true; 893 return intType.isSigned() ? isSigned : !isSigned; 894 } 895 896 /// Defaults down the subclass implementation. 897 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 898 ArrayRef<char> data, 899 int64_t dataEltSize, 900 bool isInt, bool isSigned) { 901 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 902 isSigned); 903 } 904 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 905 ArrayRef<char> data, 906 int64_t dataEltSize, 907 bool isInt, 908 bool isSigned) { 909 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 910 isInt, isSigned); 911 } 912 913 /// A method used to verify specific type invariants that the templatized 'get' 914 /// method cannot. 915 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 916 bool isSigned) const { 917 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 918 isSigned); 919 } 920 921 /// Check the information for a C++ data type, check if this type is valid for 922 /// the current attribute. 923 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 924 bool isSigned) const { 925 return ::isValidIntOrFloat( 926 getType().getElementType().cast<ComplexType>().getElementType(), 927 dataEltSize / 2, isInt, isSigned); 928 } 929 930 /// Returns if this attribute corresponds to a splat, i.e. if all element 931 /// values are the same. 932 bool DenseElementsAttr::isSplat() const { 933 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 934 } 935 936 /// Return the held element values as a range of Attributes. 937 auto DenseElementsAttr::getAttributeValues() const 938 -> llvm::iterator_range<AttributeElementIterator> { 939 return {attr_value_begin(), attr_value_end()}; 940 } 941 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 942 return AttributeElementIterator(*this, 0); 943 } 944 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 945 return AttributeElementIterator(*this, getNumElements()); 946 } 947 948 /// Return the held element values as a range of bool. The element type of 949 /// this attribute must be of integer type of bitwidth 1. 950 auto DenseElementsAttr::getBoolValues() const 951 -> llvm::iterator_range<BoolElementIterator> { 952 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 953 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 954 (void)eltType; 955 return {BoolElementIterator(*this, 0), 956 BoolElementIterator(*this, getNumElements())}; 957 } 958 959 /// Return the held element values as a range of APInts. The element type of 960 /// this attribute must be of integer type. 961 auto DenseElementsAttr::getIntValues() const 962 -> llvm::iterator_range<IntElementIterator> { 963 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 964 return {raw_int_begin(), raw_int_end()}; 965 } 966 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 967 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 968 return raw_int_begin(); 969 } 970 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 971 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 972 return raw_int_end(); 973 } 974 auto DenseElementsAttr::getComplexIntValues() const 975 -> llvm::iterator_range<ComplexIntElementIterator> { 976 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 977 (void)eltTy; 978 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 979 return {ComplexIntElementIterator(*this, 0), 980 ComplexIntElementIterator(*this, getNumElements())}; 981 } 982 983 /// Return the held element values as a range of APFloat. The element type of 984 /// this attribute must be of float type. 985 auto DenseElementsAttr::getFloatValues() const 986 -> llvm::iterator_range<FloatElementIterator> { 987 auto elementType = getType().getElementType().cast<FloatType>(); 988 const auto &elementSemantics = elementType.getFloatSemantics(); 989 return {FloatElementIterator(elementSemantics, raw_int_begin()), 990 FloatElementIterator(elementSemantics, raw_int_end())}; 991 } 992 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 993 return getFloatValues().begin(); 994 } 995 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 996 return getFloatValues().end(); 997 } 998 auto DenseElementsAttr::getComplexFloatValues() const 999 -> llvm::iterator_range<ComplexFloatElementIterator> { 1000 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 1001 assert(eltTy.isa<FloatType>() && "expected complex float type"); 1002 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1003 return {{semantics, {*this, 0}}, 1004 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1005 } 1006 1007 /// Return the raw storage data held by this attribute. 1008 ArrayRef<char> DenseElementsAttr::getRawData() const { 1009 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1010 } 1011 1012 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1013 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1014 } 1015 1016 /// Return a new DenseElementsAttr that has the same data as the current 1017 /// attribute, but has been reshaped to 'newType'. The new type must have the 1018 /// same total number of elements as well as element type. 1019 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1020 ShapedType curType = getType(); 1021 if (curType == newType) 1022 return *this; 1023 1024 (void)curType; 1025 assert(newType.getElementType() == curType.getElementType() && 1026 "expected the same element type"); 1027 assert(newType.getNumElements() == curType.getNumElements() && 1028 "expected the same number of elements"); 1029 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1030 } 1031 1032 DenseElementsAttr 1033 DenseElementsAttr::mapValues(Type newElementType, 1034 function_ref<APInt(const APInt &)> mapping) const { 1035 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1036 } 1037 1038 DenseElementsAttr DenseElementsAttr::mapValues( 1039 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1040 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1041 } 1042 1043 //===----------------------------------------------------------------------===// 1044 // DenseStringElementsAttr 1045 //===----------------------------------------------------------------------===// 1046 1047 DenseStringElementsAttr 1048 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1049 return Base::get(type.getContext(), StandardAttributes::DenseStringElements, 1050 type, values, (values.size() == 1)); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // DenseIntOrFPElementsAttr 1055 //===----------------------------------------------------------------------===// 1056 1057 /// Utility method to write a range of APInt values to a buffer. 1058 template <typename APRangeT> 1059 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1060 APRangeT &&values) { 1061 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1062 size_t offset = 0; 1063 for (auto it = values.begin(), e = values.end(); it != e; 1064 ++it, offset += storageWidth) { 1065 assert((*it).getBitWidth() <= storageWidth); 1066 writeBits(data.data(), offset, *it); 1067 } 1068 } 1069 1070 /// Constructs a dense elements attribute from an array of raw APFloat values. 1071 /// Each APFloat value is expected to have the same bitwidth as the element 1072 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1073 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1074 size_t storageWidth, 1075 ArrayRef<APFloat> values, 1076 bool isSplat) { 1077 std::vector<char> data; 1078 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1079 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1080 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1081 } 1082 1083 /// Constructs a dense elements attribute from an array of raw APInt values. 1084 /// Each APInt value is expected to have the same bitwidth as the element type 1085 /// of 'type'. 1086 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1087 size_t storageWidth, 1088 ArrayRef<APInt> values, 1089 bool isSplat) { 1090 std::vector<char> data; 1091 writeAPIntsToBuffer(storageWidth, data, values); 1092 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1093 } 1094 1095 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1096 ArrayRef<char> data, 1097 bool isSplat) { 1098 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1099 "type must be ranked tensor or vector"); 1100 assert(type.hasStaticShape() && "type must have static shape"); 1101 return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, 1102 type, data, isSplat); 1103 } 1104 1105 /// Overload of the raw 'get' method that asserts that the given type is of 1106 /// complex type. This method is used to verify type invariants that the 1107 /// templatized 'get' method cannot. 1108 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1109 ArrayRef<char> data, 1110 int64_t dataEltSize, 1111 bool isInt, 1112 bool isSigned) { 1113 assert(::isValidIntOrFloat( 1114 type.getElementType().cast<ComplexType>().getElementType(), 1115 dataEltSize / 2, isInt, isSigned)); 1116 1117 int64_t numElements = data.size() / dataEltSize; 1118 assert(numElements == 1 || numElements == type.getNumElements()); 1119 return getRaw(type, data, /*isSplat=*/numElements == 1); 1120 } 1121 1122 /// Overload of the 'getRaw' method that asserts that the given type is of 1123 /// integer type. This method is used to verify type invariants that the 1124 /// templatized 'get' method cannot. 1125 DenseElementsAttr 1126 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1127 int64_t dataEltSize, bool isInt, 1128 bool isSigned) { 1129 assert( 1130 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1131 1132 int64_t numElements = data.size() / dataEltSize; 1133 assert(numElements == 1 || numElements == type.getNumElements()); 1134 return getRaw(type, data, /*isSplat=*/numElements == 1); 1135 } 1136 1137 //===----------------------------------------------------------------------===// 1138 // DenseFPElementsAttr 1139 //===----------------------------------------------------------------------===// 1140 1141 template <typename Fn, typename Attr> 1142 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1143 Type newElementType, 1144 llvm::SmallVectorImpl<char> &data) { 1145 size_t bitWidth = getDenseElementBitWidth(newElementType); 1146 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1147 1148 ShapedType newArrayType; 1149 if (inType.isa<RankedTensorType>()) 1150 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1151 else if (inType.isa<UnrankedTensorType>()) 1152 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1153 else if (inType.isa<VectorType>()) 1154 newArrayType = VectorType::get(inType.getShape(), newElementType); 1155 else 1156 assert(newArrayType && "Unhandled tensor type"); 1157 1158 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1159 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1160 1161 // Functor used to process a single element value of the attribute. 1162 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1163 auto newInt = mapping(value); 1164 assert(newInt.getBitWidth() == bitWidth); 1165 writeBits(data.data(), index * storageBitWidth, newInt); 1166 }; 1167 1168 // Check for the splat case. 1169 if (attr.isSplat()) { 1170 processElt(*attr.begin(), /*index=*/0); 1171 return newArrayType; 1172 } 1173 1174 // Otherwise, process all of the element values. 1175 uint64_t elementIdx = 0; 1176 for (auto value : attr) 1177 processElt(value, elementIdx++); 1178 return newArrayType; 1179 } 1180 1181 DenseElementsAttr DenseFPElementsAttr::mapValues( 1182 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1183 llvm::SmallVector<char, 8> elementData; 1184 auto newArrayType = 1185 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1186 1187 return getRaw(newArrayType, elementData, isSplat()); 1188 } 1189 1190 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1191 bool DenseFPElementsAttr::classof(Attribute attr) { 1192 return attr.isa<DenseElementsAttr>() && 1193 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1194 } 1195 1196 //===----------------------------------------------------------------------===// 1197 // DenseIntElementsAttr 1198 //===----------------------------------------------------------------------===// 1199 1200 DenseElementsAttr DenseIntElementsAttr::mapValues( 1201 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1202 llvm::SmallVector<char, 8> elementData; 1203 auto newArrayType = 1204 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1205 1206 return getRaw(newArrayType, elementData, isSplat()); 1207 } 1208 1209 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1210 bool DenseIntElementsAttr::classof(Attribute attr) { 1211 return attr.isa<DenseElementsAttr>() && 1212 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // OpaqueElementsAttr 1217 //===----------------------------------------------------------------------===// 1218 1219 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1220 StringRef bytes) { 1221 assert(TensorType::isValidElementType(type.getElementType()) && 1222 "Input element type should be a valid tensor element type"); 1223 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 1224 dialect, bytes); 1225 } 1226 1227 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1228 1229 /// Return the value at the given index. If index does not refer to a valid 1230 /// element, then a null attribute is returned. 1231 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1232 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1233 if (Dialect *dialect = getDialect()) 1234 return dialect->extractElementHook(*this, index); 1235 return Attribute(); 1236 } 1237 1238 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1239 1240 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1241 if (auto *d = getDialect()) 1242 return d->decodeHook(*this, result); 1243 return true; 1244 } 1245 1246 //===----------------------------------------------------------------------===// 1247 // SparseElementsAttr 1248 //===----------------------------------------------------------------------===// 1249 1250 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1251 DenseElementsAttr indices, 1252 DenseElementsAttr values) { 1253 assert(indices.getType().getElementType().isInteger(64) && 1254 "expected sparse indices to be 64-bit integer values"); 1255 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1256 "type must be ranked tensor or vector"); 1257 assert(type.hasStaticShape() && "type must have static shape"); 1258 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 1259 indices.cast<DenseIntElementsAttr>(), values); 1260 } 1261 1262 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1263 return getImpl()->indices; 1264 } 1265 1266 DenseElementsAttr SparseElementsAttr::getValues() const { 1267 return getImpl()->values; 1268 } 1269 1270 /// Return the value of the element at the given index. 1271 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1272 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1273 auto type = getType(); 1274 1275 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1276 // as a 1-D index array. 1277 auto sparseIndices = getIndices(); 1278 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1279 1280 // Check to see if the indices are a splat. 1281 if (sparseIndices.isSplat()) { 1282 // If the index is also not a splat of the index value, we know that the 1283 // value is zero. 1284 auto splatIndex = *sparseIndexValues.begin(); 1285 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1286 return getZeroAttr(); 1287 1288 // If the indices are a splat, we also expect the values to be a splat. 1289 assert(getValues().isSplat() && "expected splat values"); 1290 return getValues().getSplatValue(); 1291 } 1292 1293 // Build a mapping between known indices and the offset of the stored element. 1294 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1295 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1296 size_t rank = type.getRank(); 1297 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1298 mappedIndices.try_emplace( 1299 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1300 1301 // Look for the provided index key within the mapped indices. If the provided 1302 // index is not found, then return a zero attribute. 1303 auto it = mappedIndices.find(index); 1304 if (it == mappedIndices.end()) 1305 return getZeroAttr(); 1306 1307 // Otherwise, return the held sparse value element. 1308 return getValues().getValue(it->second); 1309 } 1310 1311 /// Get a zero APFloat for the given sparse attribute. 1312 APFloat SparseElementsAttr::getZeroAPFloat() const { 1313 auto eltType = getType().getElementType().cast<FloatType>(); 1314 return APFloat(eltType.getFloatSemantics()); 1315 } 1316 1317 /// Get a zero APInt for the given sparse attribute. 1318 APInt SparseElementsAttr::getZeroAPInt() const { 1319 auto eltType = getType().getElementType().cast<IntegerType>(); 1320 return APInt::getNullValue(eltType.getWidth()); 1321 } 1322 1323 /// Get a zero attribute for the given attribute type. 1324 Attribute SparseElementsAttr::getZeroAttr() const { 1325 auto eltType = getType().getElementType(); 1326 1327 // Handle floating point elements. 1328 if (eltType.isa<FloatType>()) 1329 return FloatAttr::get(eltType, 0); 1330 1331 // Otherwise, this is an integer. 1332 auto intEltTy = eltType.cast<IntegerType>(); 1333 if (intEltTy.getWidth() == 1) 1334 return BoolAttr::get(false, eltType.getContext()); 1335 return IntegerAttr::get(eltType, 0); 1336 } 1337 1338 /// Flatten, and return, all of the sparse indices in this attribute in 1339 /// row-major order. 1340 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1341 std::vector<ptrdiff_t> flatSparseIndices; 1342 1343 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1344 // as a 1-D index array. 1345 auto sparseIndices = getIndices(); 1346 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1347 if (sparseIndices.isSplat()) { 1348 SmallVector<uint64_t, 8> indices(getType().getRank(), 1349 *sparseIndexValues.begin()); 1350 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1351 return flatSparseIndices; 1352 } 1353 1354 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1355 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1356 size_t rank = getType().getRank(); 1357 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1358 flatSparseIndices.push_back(getFlattenedIndex( 1359 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1360 return flatSparseIndices; 1361 } 1362 1363 //===----------------------------------------------------------------------===// 1364 // MutableDictionaryAttr 1365 //===----------------------------------------------------------------------===// 1366 1367 MutableDictionaryAttr::MutableDictionaryAttr( 1368 ArrayRef<NamedAttribute> attributes) { 1369 setAttrs(attributes); 1370 } 1371 1372 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { 1373 return attrs ? attrs.getValue() : llvm::None; 1374 } 1375 1376 /// Replace the held attributes with ones provided in 'newAttrs'. 1377 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { 1378 // Don't create an attribute list if there are no attributes. 1379 if (attributes.empty()) 1380 attrs = nullptr; 1381 else 1382 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1383 } 1384 1385 /// Return the specified attribute if present, null otherwise. 1386 Attribute MutableDictionaryAttr::get(StringRef name) const { 1387 return attrs ? attrs.get(name) : nullptr; 1388 } 1389 1390 /// Return the specified attribute if present, null otherwise. 1391 Attribute MutableDictionaryAttr::get(Identifier name) const { 1392 return attrs ? attrs.get(name) : nullptr; 1393 } 1394 1395 /// Return the specified named attribute if present, None otherwise. 1396 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1397 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1398 } 1399 Optional<NamedAttribute> 1400 MutableDictionaryAttr::getNamed(Identifier name) const { 1401 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1402 } 1403 1404 /// If the an attribute exists with the specified name, change it to the new 1405 /// value. Otherwise, add a new attribute with the specified name/value. 1406 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1407 assert(value && "attributes may never be null"); 1408 1409 // Look for an existing value for the given name, and set it in-place. 1410 ArrayRef<NamedAttribute> values = getAttrs(); 1411 auto it = llvm::find_if( 1412 values, [name](NamedAttribute attr) { return attr.first == name; }); 1413 if (it != values.end()) { 1414 // Bail out early if the value is the same as what we already have. 1415 if (it->second == value) 1416 return; 1417 1418 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1419 newAttrs[it - values.begin()].second = value; 1420 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1421 return; 1422 } 1423 1424 // Otherwise, insert the new attribute into its sorted position. 1425 it = llvm::lower_bound(values, name, compareNamedAttributeWithName); 1426 SmallVector<NamedAttribute, 8> newAttrs; 1427 newAttrs.reserve(values.size() + 1); 1428 newAttrs.append(values.begin(), it); 1429 newAttrs.push_back({name, value}); 1430 newAttrs.append(it, values.end()); 1431 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1432 } 1433 1434 /// Remove the attribute with the specified name if it exists. The return 1435 /// value indicates whether the attribute was present or not. 1436 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1437 auto origAttrs = getAttrs(); 1438 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1439 if (origAttrs[i].first == name) { 1440 // Handle the simple case of removing the only attribute in the list. 1441 if (e == 1) { 1442 attrs = nullptr; 1443 return RemoveResult::Removed; 1444 } 1445 1446 SmallVector<NamedAttribute, 8> newAttrs; 1447 newAttrs.reserve(origAttrs.size() - 1); 1448 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1449 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1450 attrs = DictionaryAttr::getWithSorted(newAttrs, 1451 newAttrs[0].second.getContext()); 1452 return RemoveResult::Removed; 1453 } 1454 } 1455 return RemoveResult::NotFound; 1456 } 1457