1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AttributeStorage 26 //===----------------------------------------------------------------------===// 27 28 AttributeStorage::AttributeStorage(Type type) 29 : type(type.getAsOpaquePointer()) {} 30 AttributeStorage::AttributeStorage() : type(nullptr) {} 31 32 Type AttributeStorage::getType() const { 33 return Type::getFromOpaquePointer(type); 34 } 35 void AttributeStorage::setType(Type newType) { 36 type = newType.getAsOpaquePointer(); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Attribute 41 //===----------------------------------------------------------------------===// 42 43 /// Return the type of this attribute. 44 Type Attribute::getType() const { return impl->getType(); } 45 46 /// Return the context this attribute belongs to. 47 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 48 49 /// Get the dialect this attribute is registered to. 50 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 51 52 //===----------------------------------------------------------------------===// 53 // AffineMapAttr 54 //===----------------------------------------------------------------------===// 55 56 AffineMapAttr AffineMapAttr::get(AffineMap value) { 57 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 58 } 59 60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 61 62 //===----------------------------------------------------------------------===// 63 // ArrayAttr 64 //===----------------------------------------------------------------------===// 65 66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 67 return Base::get(context, StandardAttributes::Array, value); 68 } 69 70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 71 72 Attribute ArrayAttr::operator[](unsigned idx) const { 73 assert(idx < size() && "index out of bounds"); 74 return getValue()[idx]; 75 } 76 77 //===----------------------------------------------------------------------===// 78 // BoolAttr 79 //===----------------------------------------------------------------------===// 80 81 bool BoolAttr::getValue() const { return getImpl()->value; } 82 83 //===----------------------------------------------------------------------===// 84 // DictionaryAttr 85 //===----------------------------------------------------------------------===// 86 87 /// Helper function that does either an in place sort or sorts from source array 88 /// into destination. If inPlace then storage is both the source and the 89 /// destination, else value is the source and storage destination. Returns 90 /// whether source was sorted. 91 template <bool inPlace> 92 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 93 SmallVectorImpl<NamedAttribute> &storage) { 94 // Specialize for the common case. 95 switch (value.size()) { 96 case 0: 97 // Zero already sorted. 98 break; 99 case 1: 100 // One already sorted but may need to be copied. 101 if (!inPlace) 102 storage.assign({value[0]}); 103 break; 104 case 2: { 105 assert(value[0].first != value[1].first && 106 "DictionaryAttr element names must be unique"); 107 bool isSorted = value[0] < value[1]; 108 if (inPlace) { 109 if (!isSorted) 110 std::swap(storage[0], storage[1]); 111 } else if (isSorted) { 112 storage.assign({value[0], value[1]}); 113 } else { 114 storage.assign({value[1], value[0]}); 115 } 116 return !isSorted; 117 } 118 default: 119 if (!inPlace) 120 storage.assign(value.begin(), value.end()); 121 // Check to see they are sorted already. 122 bool isSorted = llvm::is_sorted(value); 123 if (!isSorted) { 124 // If not, do a general sort. 125 llvm::array_pod_sort(storage.begin(), storage.end()); 126 value = storage; 127 } 128 129 // Ensure that the attribute elements are unique. 130 assert(std::adjacent_find(value.begin(), value.end(), 131 [](NamedAttribute l, NamedAttribute r) { 132 return l.first == r.first; 133 }) == value.end() && 134 "DictionaryAttr element names must be unique"); 135 return !isSorted; 136 } 137 return false; 138 } 139 140 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 141 SmallVectorImpl<NamedAttribute> &storage) { 142 return dictionaryAttrSort</*inPlace=*/false>(value, storage); 143 } 144 145 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 146 return dictionaryAttrSort</*inPlace=*/true>(array, array); 147 } 148 149 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 150 MLIRContext *context) { 151 if (value.empty()) 152 return DictionaryAttr::getEmpty(context); 153 assert(llvm::all_of(value, 154 [](const NamedAttribute &attr) { return attr.second; }) && 155 "value cannot have null entries"); 156 157 // We need to sort the element list to canonicalize it. 158 SmallVector<NamedAttribute, 8> storage; 159 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 160 value = storage; 161 162 return Base::get(context, StandardAttributes::Dictionary, value); 163 } 164 /// Construct a dictionary with an array of values that is known to already be 165 /// sorted by name and uniqued. 166 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 167 MLIRContext *context) { 168 if (value.empty()) 169 return DictionaryAttr::getEmpty(context); 170 // Ensure that the attribute elements are unique and sorted. 171 assert(llvm::is_sorted(value, 172 [](NamedAttribute l, NamedAttribute r) { 173 return l.first.strref() < r.first.strref(); 174 }) && 175 "expected attribute values to be sorted"); 176 assert(std::adjacent_find(value.begin(), value.end(), 177 [](NamedAttribute l, NamedAttribute r) { 178 return l.first == r.first; 179 }) == value.end() && 180 "DictionaryAttr element names must be unique"); 181 return Base::get(context, StandardAttributes::Dictionary, value); 182 } 183 184 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 185 return getImpl()->getElements(); 186 } 187 188 /// Return the specified attribute if present, null otherwise. 189 Attribute DictionaryAttr::get(StringRef name) const { 190 Optional<NamedAttribute> attr = getNamed(name); 191 return attr ? attr->second : nullptr; 192 } 193 Attribute DictionaryAttr::get(Identifier name) const { 194 Optional<NamedAttribute> attr = getNamed(name); 195 return attr ? attr->second : nullptr; 196 } 197 198 /// Return the specified named attribute if present, None otherwise. 199 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 200 ArrayRef<NamedAttribute> values = getValue(); 201 const auto *it = llvm::lower_bound(values, name); 202 return it != values.end() && it->first == name ? *it 203 : Optional<NamedAttribute>(); 204 } 205 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 206 for (auto elt : getValue()) 207 if (elt.first == name) 208 return elt; 209 return llvm::None; 210 } 211 212 DictionaryAttr::iterator DictionaryAttr::begin() const { 213 return getValue().begin(); 214 } 215 DictionaryAttr::iterator DictionaryAttr::end() const { 216 return getValue().end(); 217 } 218 size_t DictionaryAttr::size() const { return getValue().size(); } 219 220 //===----------------------------------------------------------------------===// 221 // FloatAttr 222 //===----------------------------------------------------------------------===// 223 224 FloatAttr FloatAttr::get(Type type, double value) { 225 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 226 } 227 228 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 229 return Base::getChecked(loc, StandardAttributes::Float, type, value); 230 } 231 232 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 233 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 234 } 235 236 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 237 return Base::getChecked(loc, StandardAttributes::Float, type, value); 238 } 239 240 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 241 242 double FloatAttr::getValueAsDouble() const { 243 return getValueAsDouble(getValue()); 244 } 245 double FloatAttr::getValueAsDouble(APFloat value) { 246 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 247 bool losesInfo = false; 248 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 249 &losesInfo); 250 } 251 return value.convertToDouble(); 252 } 253 254 /// Verify construction invariants. 255 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 256 if (!type.isa<FloatType>()) 257 return emitError(loc, "expected floating point type"); 258 return success(); 259 } 260 261 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 262 double value) { 263 return verifyFloatTypeInvariants(loc, type); 264 } 265 266 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 267 const APFloat &value) { 268 // Verify that the type is correct. 269 if (failed(verifyFloatTypeInvariants(loc, type))) 270 return failure(); 271 272 // Verify that the type semantics match that of the value. 273 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 274 return emitError( 275 loc, "FloatAttr type doesn't match the type implied by its value"); 276 } 277 return success(); 278 } 279 280 //===----------------------------------------------------------------------===// 281 // SymbolRefAttr 282 //===----------------------------------------------------------------------===// 283 284 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 285 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 286 .cast<FlatSymbolRefAttr>(); 287 } 288 289 SymbolRefAttr SymbolRefAttr::get(StringRef value, 290 ArrayRef<FlatSymbolRefAttr> nestedReferences, 291 MLIRContext *ctx) { 292 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 293 } 294 295 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 296 297 StringRef SymbolRefAttr::getLeafReference() const { 298 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 299 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 300 } 301 302 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 303 return getImpl()->getNestedRefs(); 304 } 305 306 //===----------------------------------------------------------------------===// 307 // IntegerAttr 308 //===----------------------------------------------------------------------===// 309 310 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 311 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 312 } 313 314 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 315 // This uses 64 bit APInts by default for index type. 316 if (type.isIndex()) 317 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 318 319 auto intType = type.cast<IntegerType>(); 320 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 321 } 322 323 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 324 325 int64_t IntegerAttr::getInt() const { 326 assert((getImpl()->getType().isIndex() || 327 getImpl()->getType().isSignlessInteger()) && 328 "must be signless integer"); 329 return getValue().getSExtValue(); 330 } 331 332 int64_t IntegerAttr::getSInt() const { 333 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 334 return getValue().getSExtValue(); 335 } 336 337 uint64_t IntegerAttr::getUInt() const { 338 assert(getImpl()->getType().isUnsignedInteger() && 339 "must be unsigned integer"); 340 return getValue().getZExtValue(); 341 } 342 343 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 344 if (type.isa<IntegerType>() || type.isa<IndexType>()) 345 return success(); 346 return emitError(loc, "expected integer or index type"); 347 } 348 349 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 350 int64_t value) { 351 return verifyIntegerTypeInvariants(loc, type); 352 } 353 354 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 355 const APInt &value) { 356 if (failed(verifyIntegerTypeInvariants(loc, type))) 357 return failure(); 358 if (auto integerType = type.dyn_cast<IntegerType>()) 359 if (integerType.getWidth() != value.getBitWidth()) 360 return emitError(loc, "integer type bit width (") 361 << integerType.getWidth() << ") doesn't match value bit width (" 362 << value.getBitWidth() << ")"; 363 return success(); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // IntegerSetAttr 368 //===----------------------------------------------------------------------===// 369 370 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 371 return Base::get(value.getConstraint(0).getContext(), 372 StandardAttributes::IntegerSet, value); 373 } 374 375 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 376 377 //===----------------------------------------------------------------------===// 378 // OpaqueAttr 379 //===----------------------------------------------------------------------===// 380 381 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 382 MLIRContext *context) { 383 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 384 type); 385 } 386 387 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 388 Type type, Location location) { 389 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 390 attrData, type); 391 } 392 393 /// Returns the dialect namespace of the opaque attribute. 394 Identifier OpaqueAttr::getDialectNamespace() const { 395 return getImpl()->dialectNamespace; 396 } 397 398 /// Returns the raw attribute data of the opaque attribute. 399 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 400 401 /// Verify the construction of an opaque attribute. 402 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 403 Identifier dialect, 404 StringRef attrData, 405 Type type) { 406 if (!Dialect::isValidNamespace(dialect.strref())) 407 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 408 return success(); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // StringAttr 413 //===----------------------------------------------------------------------===// 414 415 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 416 return get(bytes, NoneType::get(context)); 417 } 418 419 /// Get an instance of a StringAttr with the given string and Type. 420 StringAttr StringAttr::get(StringRef bytes, Type type) { 421 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 422 } 423 424 StringRef StringAttr::getValue() const { return getImpl()->value; } 425 426 //===----------------------------------------------------------------------===// 427 // TypeAttr 428 //===----------------------------------------------------------------------===// 429 430 TypeAttr TypeAttr::get(Type value) { 431 return Base::get(value.getContext(), StandardAttributes::Type, value); 432 } 433 434 Type TypeAttr::getValue() const { return getImpl()->value; } 435 436 //===----------------------------------------------------------------------===// 437 // ElementsAttr 438 //===----------------------------------------------------------------------===// 439 440 ShapedType ElementsAttr::getType() const { 441 return Attribute::getType().cast<ShapedType>(); 442 } 443 444 /// Returns the number of elements held by this attribute. 445 int64_t ElementsAttr::getNumElements() const { 446 return getType().getNumElements(); 447 } 448 449 /// Return the value at the given index. If index does not refer to a valid 450 /// element, then a null attribute is returned. 451 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 452 switch (getKind()) { 453 case StandardAttributes::DenseIntOrFPElements: 454 return cast<DenseElementsAttr>().getValue(index); 455 case StandardAttributes::OpaqueElements: 456 return cast<OpaqueElementsAttr>().getValue(index); 457 case StandardAttributes::SparseElements: 458 return cast<SparseElementsAttr>().getValue(index); 459 default: 460 llvm_unreachable("unknown ElementsAttr kind"); 461 } 462 } 463 464 /// Return if the given 'index' refers to a valid element in this attribute. 465 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 466 auto type = getType(); 467 468 // Verify that the rank of the indices matches the held type. 469 auto rank = type.getRank(); 470 if (rank != static_cast<int64_t>(index.size())) 471 return false; 472 473 // Verify that all of the indices are within the shape dimensions. 474 auto shape = type.getShape(); 475 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 476 return static_cast<int64_t>(index[i]) < shape[i]; 477 }); 478 } 479 480 ElementsAttr 481 ElementsAttr::mapValues(Type newElementType, 482 function_ref<APInt(const APInt &)> mapping) const { 483 switch (getKind()) { 484 case StandardAttributes::DenseIntOrFPElements: 485 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 486 default: 487 llvm_unreachable("unsupported ElementsAttr subtype"); 488 } 489 } 490 491 ElementsAttr 492 ElementsAttr::mapValues(Type newElementType, 493 function_ref<APInt(const APFloat &)> mapping) const { 494 switch (getKind()) { 495 case StandardAttributes::DenseIntOrFPElements: 496 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 497 default: 498 llvm_unreachable("unsupported ElementsAttr subtype"); 499 } 500 } 501 502 /// Returns the 1 dimensional flattened row-major index from the given 503 /// multi-dimensional index. 504 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 505 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 506 auto type = getType(); 507 508 // Reduce the provided multidimensional index into a flattended 1D row-major 509 // index. 510 auto rank = type.getRank(); 511 auto shape = type.getShape(); 512 uint64_t valueIndex = 0; 513 uint64_t dimMultiplier = 1; 514 for (int i = rank - 1; i >= 0; --i) { 515 valueIndex += index[i] * dimMultiplier; 516 dimMultiplier *= shape[i]; 517 } 518 return valueIndex; 519 } 520 521 //===----------------------------------------------------------------------===// 522 // DenseElementAttr Utilities 523 //===----------------------------------------------------------------------===// 524 525 /// Get the bitwidth of a dense element type within the buffer. 526 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 527 static size_t getDenseElementStorageWidth(size_t origWidth) { 528 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 529 } 530 static size_t getDenseElementStorageWidth(Type elementType) { 531 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 532 } 533 534 /// Set a bit to a specific value. 535 static void setBit(char *rawData, size_t bitPos, bool value) { 536 if (value) 537 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 538 else 539 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 540 } 541 542 /// Return the value of the specified bit. 543 static bool getBit(const char *rawData, size_t bitPos) { 544 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 545 } 546 547 /// Get start position of actual data in `value`. Actual data is 548 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. 549 static char *getAPIntDataPos(APInt &value, size_t bitWidth) { 550 char *dataPos = 551 const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); 552 if (llvm::support::endian::system_endianness() == 553 llvm::support::endianness::big) 554 dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); 555 return dataPos; 556 } 557 558 /// Read APInt `value` from appropriate position. 559 static void readAPInt(APInt &value, size_t bitWidth, char *outData) { 560 char *dataPos = getAPIntDataPos(value, bitWidth); 561 std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); 562 } 563 564 /// Write `inData` to appropriate position of APInt `value`. 565 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { 566 char *dataPos = getAPIntDataPos(value, bitWidth); 567 std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); 568 } 569 570 /// Writes value to the bit position `bitPos` in array `rawData`. 571 static void writeBits(char *rawData, size_t bitPos, APInt value) { 572 size_t bitWidth = value.getBitWidth(); 573 574 // If the bitwidth is 1 we just toggle the specific bit. 575 if (bitWidth == 1) 576 return setBit(rawData, bitPos, value.isOneValue()); 577 578 // Otherwise, the bit position is guaranteed to be byte aligned. 579 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 580 readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); 581 } 582 583 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 584 /// `rawData`. 585 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 586 // Handle a boolean bit position. 587 if (bitWidth == 1) 588 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 589 590 // Otherwise, the bit position must be 8-bit aligned. 591 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 592 APInt result(bitWidth, 0); 593 writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); 594 return result; 595 } 596 597 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 598 /// same element count as 'type'. 599 template <typename Values> 600 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 601 return (values.size() == 1) || 602 (type.getNumElements() == static_cast<int64_t>(values.size())); 603 } 604 605 //===----------------------------------------------------------------------===// 606 // DenseElementAttr Iterators 607 //===----------------------------------------------------------------------===// 608 609 //===----------------------------------------------------------------------===// 610 // AttributeElementIterator 611 612 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 613 DenseElementsAttr attr, size_t index) 614 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 615 Attribute, Attribute, Attribute>( 616 attr.getAsOpaquePointer(), index) {} 617 618 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 619 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 620 Type eltTy = owner.getType().getElementType(); 621 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 622 if (intEltTy.getWidth() == 1) 623 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 624 owner.getContext()); 625 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 626 } 627 if (eltTy.isa<IndexType>()) 628 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 629 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 630 IntElementIterator intIt(owner, index); 631 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 632 return FloatAttr::get(eltTy, *floatIt); 633 } 634 if (owner.isa<DenseStringElementsAttr>()) { 635 ArrayRef<StringRef> vals = owner.getRawStringData(); 636 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 637 } 638 llvm_unreachable("unexpected element type"); 639 } 640 641 //===----------------------------------------------------------------------===// 642 // BoolElementIterator 643 644 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 645 DenseElementsAttr attr, size_t dataIndex) 646 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 647 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 648 649 bool DenseElementsAttr::BoolElementIterator::operator*() const { 650 return getBit(getData(), getDataIndex()); 651 } 652 653 //===----------------------------------------------------------------------===// 654 // IntElementIterator 655 656 DenseElementsAttr::IntElementIterator::IntElementIterator( 657 DenseElementsAttr attr, size_t dataIndex) 658 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 659 attr.getRawData().data(), attr.isSplat(), dataIndex), 660 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 661 662 APInt DenseElementsAttr::IntElementIterator::operator*() const { 663 return readBits(getData(), 664 getDataIndex() * getDenseElementStorageWidth(bitWidth), 665 bitWidth); 666 } 667 668 //===----------------------------------------------------------------------===// 669 // ComplexIntElementIterator 670 671 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 672 DenseElementsAttr attr, size_t dataIndex) 673 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 674 std::complex<APInt>, std::complex<APInt>, 675 std::complex<APInt>>( 676 attr.getRawData().data(), attr.isSplat(), dataIndex) { 677 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 678 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 679 } 680 681 std::complex<APInt> 682 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 683 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 684 size_t offset = getDataIndex() * storageWidth * 2; 685 return {readBits(getData(), offset, bitWidth), 686 readBits(getData(), offset + storageWidth, bitWidth)}; 687 } 688 689 //===----------------------------------------------------------------------===// 690 // FloatElementIterator 691 692 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 693 const llvm::fltSemantics &smt, IntElementIterator it) 694 : llvm::mapped_iterator<IntElementIterator, 695 std::function<APFloat(const APInt &)>>( 696 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 697 698 //===----------------------------------------------------------------------===// 699 // ComplexFloatElementIterator 700 701 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 702 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 703 : llvm::mapped_iterator< 704 ComplexIntElementIterator, 705 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 706 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 707 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 708 }) {} 709 710 //===----------------------------------------------------------------------===// 711 // DenseElementsAttr 712 //===----------------------------------------------------------------------===// 713 714 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 715 ArrayRef<Attribute> values) { 716 assert(hasSameElementsOrSplat(type, values)); 717 718 // If the element type is not based on int/float/index, assume it is a string 719 // type. 720 auto eltType = type.getElementType(); 721 if (!type.getElementType().isIntOrIndexOrFloat()) { 722 SmallVector<StringRef, 8> stringValues; 723 stringValues.reserve(values.size()); 724 for (Attribute attr : values) { 725 assert(attr.isa<StringAttr>() && 726 "expected string value for non integer/index/float element"); 727 stringValues.push_back(attr.cast<StringAttr>().getValue()); 728 } 729 return get(type, stringValues); 730 } 731 732 // Otherwise, get the raw storage width to use for the allocation. 733 size_t bitWidth = getDenseElementBitWidth(eltType); 734 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 735 736 // Compress the attribute values into a character buffer. 737 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 738 values.size()); 739 APInt intVal; 740 for (unsigned i = 0, e = values.size(); i < e; ++i) { 741 assert(eltType == values[i].getType() && 742 "expected attribute value to have element type"); 743 744 switch (eltType.getKind()) { 745 case StandardTypes::BF16: 746 case StandardTypes::F16: 747 case StandardTypes::F32: 748 case StandardTypes::F64: 749 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 750 break; 751 case StandardTypes::Integer: 752 case StandardTypes::Index: 753 intVal = values[i].isa<BoolAttr>() 754 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 755 : values[i].cast<IntegerAttr>().getValue(); 756 break; 757 default: 758 llvm_unreachable("unexpected element type"); 759 } 760 assert(intVal.getBitWidth() == bitWidth && 761 "expected value to have same bitwidth as element type"); 762 writeBits(data.data(), i * storageBitWidth, intVal); 763 } 764 return DenseIntOrFPElementsAttr::getRaw(type, data, 765 /*isSplat=*/(values.size() == 1)); 766 } 767 768 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 769 ArrayRef<bool> values) { 770 assert(hasSameElementsOrSplat(type, values)); 771 assert(type.getElementType().isInteger(1)); 772 773 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 774 for (int i = 0, e = values.size(); i != e; ++i) 775 setBit(buff.data(), i, values[i]); 776 return DenseIntOrFPElementsAttr::getRaw(type, buff, 777 /*isSplat=*/(values.size() == 1)); 778 } 779 780 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 781 ArrayRef<StringRef> values) { 782 assert(!type.getElementType().isIntOrFloat()); 783 return DenseStringElementsAttr::get(type, values); 784 } 785 786 /// Constructs a dense integer elements attribute from an array of APInt 787 /// values. Each APInt value is expected to have the same bitwidth as the 788 /// element type of 'type'. 789 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 790 ArrayRef<APInt> values) { 791 assert(type.getElementType().isIntOrIndex()); 792 assert(hasSameElementsOrSplat(type, values)); 793 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 794 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 795 /*isSplat=*/(values.size() == 1)); 796 } 797 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 798 ArrayRef<std::complex<APInt>> values) { 799 ComplexType complex = type.getElementType().cast<ComplexType>(); 800 assert(complex.getElementType().isa<IntegerType>()); 801 assert(hasSameElementsOrSplat(type, values)); 802 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 803 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 804 values.size() * 2); 805 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 806 /*isSplat=*/(values.size() == 1)); 807 } 808 809 // Constructs a dense float elements attribute from an array of APFloat 810 // values. Each APFloat value is expected to have the same bitwidth as the 811 // element type of 'type'. 812 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 813 ArrayRef<APFloat> values) { 814 assert(type.getElementType().isa<FloatType>()); 815 assert(hasSameElementsOrSplat(type, values)); 816 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 817 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 818 /*isSplat=*/(values.size() == 1)); 819 } 820 DenseElementsAttr 821 DenseElementsAttr::get(ShapedType type, 822 ArrayRef<std::complex<APFloat>> values) { 823 ComplexType complex = type.getElementType().cast<ComplexType>(); 824 assert(complex.getElementType().isa<FloatType>()); 825 assert(hasSameElementsOrSplat(type, values)); 826 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 827 values.size() * 2); 828 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 829 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 830 /*isSplat=*/(values.size() == 1)); 831 } 832 833 /// Construct a dense elements attribute from a raw buffer representing the 834 /// data for this attribute. Users should generally not use this methods as 835 /// the expected buffer format may not be a form the user expects. 836 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 837 ArrayRef<char> rawBuffer, 838 bool isSplatBuffer) { 839 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 840 } 841 842 /// Returns true if the given buffer is a valid raw buffer for the given type. 843 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 844 ArrayRef<char> rawBuffer, 845 bool &detectedSplat) { 846 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 847 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 848 849 // Storage width of 1 is special as it is packed by the bit. 850 if (storageWidth == 1) { 851 // Check for a splat, or a buffer equal to the number of elements. 852 if ((detectedSplat = rawBuffer.size() == 1)) 853 return true; 854 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 855 } 856 // All other types are 8-bit aligned. 857 if ((detectedSplat = rawBufferWidth == storageWidth)) 858 return true; 859 return rawBufferWidth == (storageWidth * type.getNumElements()); 860 } 861 862 /// Check the information for a C++ data type, check if this type is valid for 863 /// the current attribute. This method is used to verify specific type 864 /// invariants that the templatized 'getValues' method cannot. 865 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 866 bool isSigned) { 867 // Make sure that the data element size is the same as the type element width. 868 if (getDenseElementBitWidth(type) != 869 static_cast<size_t>(dataEltSize * CHAR_BIT)) 870 return false; 871 872 // Check that the element type is either float or integer or index. 873 if (!isInt) 874 return type.isa<FloatType>(); 875 if (type.isIndex()) 876 return true; 877 878 auto intType = type.dyn_cast<IntegerType>(); 879 if (!intType) 880 return false; 881 882 // Make sure signedness semantics is consistent. 883 if (intType.isSignless()) 884 return true; 885 return intType.isSigned() ? isSigned : !isSigned; 886 } 887 888 /// Defaults down the subclass implementation. 889 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 890 ArrayRef<char> data, 891 int64_t dataEltSize, 892 bool isInt, bool isSigned) { 893 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 894 isSigned); 895 } 896 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 897 ArrayRef<char> data, 898 int64_t dataEltSize, 899 bool isInt, 900 bool isSigned) { 901 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 902 isInt, isSigned); 903 } 904 905 /// A method used to verify specific type invariants that the templatized 'get' 906 /// method cannot. 907 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 908 bool isSigned) const { 909 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 910 isSigned); 911 } 912 913 /// Check the information for a C++ data type, check if this type is valid for 914 /// the current attribute. 915 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 916 bool isSigned) const { 917 return ::isValidIntOrFloat( 918 getType().getElementType().cast<ComplexType>().getElementType(), 919 dataEltSize / 2, isInt, isSigned); 920 } 921 922 /// Returns if this attribute corresponds to a splat, i.e. if all element 923 /// values are the same. 924 bool DenseElementsAttr::isSplat() const { 925 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 926 } 927 928 /// Return the held element values as a range of Attributes. 929 auto DenseElementsAttr::getAttributeValues() const 930 -> llvm::iterator_range<AttributeElementIterator> { 931 return {attr_value_begin(), attr_value_end()}; 932 } 933 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 934 return AttributeElementIterator(*this, 0); 935 } 936 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 937 return AttributeElementIterator(*this, getNumElements()); 938 } 939 940 /// Return the held element values as a range of bool. The element type of 941 /// this attribute must be of integer type of bitwidth 1. 942 auto DenseElementsAttr::getBoolValues() const 943 -> llvm::iterator_range<BoolElementIterator> { 944 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 945 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 946 (void)eltType; 947 return {BoolElementIterator(*this, 0), 948 BoolElementIterator(*this, getNumElements())}; 949 } 950 951 /// Return the held element values as a range of APInts. The element type of 952 /// this attribute must be of integer type. 953 auto DenseElementsAttr::getIntValues() const 954 -> llvm::iterator_range<IntElementIterator> { 955 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 956 return {raw_int_begin(), raw_int_end()}; 957 } 958 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 959 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 960 return raw_int_begin(); 961 } 962 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 963 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 964 return raw_int_end(); 965 } 966 auto DenseElementsAttr::getComplexIntValues() const 967 -> llvm::iterator_range<ComplexIntElementIterator> { 968 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 969 (void)eltTy; 970 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 971 return {ComplexIntElementIterator(*this, 0), 972 ComplexIntElementIterator(*this, getNumElements())}; 973 } 974 975 /// Return the held element values as a range of APFloat. The element type of 976 /// this attribute must be of float type. 977 auto DenseElementsAttr::getFloatValues() const 978 -> llvm::iterator_range<FloatElementIterator> { 979 auto elementType = getType().getElementType().cast<FloatType>(); 980 const auto &elementSemantics = elementType.getFloatSemantics(); 981 return {FloatElementIterator(elementSemantics, raw_int_begin()), 982 FloatElementIterator(elementSemantics, raw_int_end())}; 983 } 984 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 985 return getFloatValues().begin(); 986 } 987 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 988 return getFloatValues().end(); 989 } 990 auto DenseElementsAttr::getComplexFloatValues() const 991 -> llvm::iterator_range<ComplexFloatElementIterator> { 992 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 993 assert(eltTy.isa<FloatType>() && "expected complex float type"); 994 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 995 return {{semantics, {*this, 0}}, 996 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 997 } 998 999 /// Return the raw storage data held by this attribute. 1000 ArrayRef<char> DenseElementsAttr::getRawData() const { 1001 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1002 } 1003 1004 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1005 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1006 } 1007 1008 /// Return a new DenseElementsAttr that has the same data as the current 1009 /// attribute, but has been reshaped to 'newType'. The new type must have the 1010 /// same total number of elements as well as element type. 1011 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1012 ShapedType curType = getType(); 1013 if (curType == newType) 1014 return *this; 1015 1016 (void)curType; 1017 assert(newType.getElementType() == curType.getElementType() && 1018 "expected the same element type"); 1019 assert(newType.getNumElements() == curType.getNumElements() && 1020 "expected the same number of elements"); 1021 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1022 } 1023 1024 DenseElementsAttr 1025 DenseElementsAttr::mapValues(Type newElementType, 1026 function_ref<APInt(const APInt &)> mapping) const { 1027 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1028 } 1029 1030 DenseElementsAttr DenseElementsAttr::mapValues( 1031 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1032 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1033 } 1034 1035 //===----------------------------------------------------------------------===// 1036 // DenseStringElementsAttr 1037 //===----------------------------------------------------------------------===// 1038 1039 DenseStringElementsAttr 1040 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1041 return Base::get(type.getContext(), StandardAttributes::DenseStringElements, 1042 type, values, (values.size() == 1)); 1043 } 1044 1045 //===----------------------------------------------------------------------===// 1046 // DenseIntOrFPElementsAttr 1047 //===----------------------------------------------------------------------===// 1048 1049 /// Utility method to write a range of APInt values to a buffer. 1050 template <typename APRangeT> 1051 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1052 APRangeT &&values) { 1053 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1054 size_t offset = 0; 1055 for (auto it = values.begin(), e = values.end(); it != e; 1056 ++it, offset += storageWidth) { 1057 assert((*it).getBitWidth() <= storageWidth); 1058 writeBits(data.data(), offset, *it); 1059 } 1060 } 1061 1062 /// Constructs a dense elements attribute from an array of raw APFloat values. 1063 /// Each APFloat value is expected to have the same bitwidth as the element 1064 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1065 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1066 size_t storageWidth, 1067 ArrayRef<APFloat> values, 1068 bool isSplat) { 1069 std::vector<char> data; 1070 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1071 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1072 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1073 } 1074 1075 /// Constructs a dense elements attribute from an array of raw APInt values. 1076 /// Each APInt value is expected to have the same bitwidth as the element type 1077 /// of 'type'. 1078 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1079 size_t storageWidth, 1080 ArrayRef<APInt> values, 1081 bool isSplat) { 1082 std::vector<char> data; 1083 writeAPIntsToBuffer(storageWidth, data, values); 1084 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1085 } 1086 1087 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1088 ArrayRef<char> data, 1089 bool isSplat) { 1090 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1091 "type must be ranked tensor or vector"); 1092 assert(type.hasStaticShape() && "type must have static shape"); 1093 return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, 1094 type, data, isSplat); 1095 } 1096 1097 /// Overload of the raw 'get' method that asserts that the given type is of 1098 /// complex type. This method is used to verify type invariants that the 1099 /// templatized 'get' method cannot. 1100 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1101 ArrayRef<char> data, 1102 int64_t dataEltSize, 1103 bool isInt, 1104 bool isSigned) { 1105 assert(::isValidIntOrFloat( 1106 type.getElementType().cast<ComplexType>().getElementType(), 1107 dataEltSize / 2, isInt, isSigned)); 1108 1109 int64_t numElements = data.size() / dataEltSize; 1110 assert(numElements == 1 || numElements == type.getNumElements()); 1111 return getRaw(type, data, /*isSplat=*/numElements == 1); 1112 } 1113 1114 /// Overload of the 'getRaw' method that asserts that the given type is of 1115 /// integer type. This method is used to verify type invariants that the 1116 /// templatized 'get' method cannot. 1117 DenseElementsAttr 1118 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1119 int64_t dataEltSize, bool isInt, 1120 bool isSigned) { 1121 assert( 1122 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1123 1124 int64_t numElements = data.size() / dataEltSize; 1125 assert(numElements == 1 || numElements == type.getNumElements()); 1126 return getRaw(type, data, /*isSplat=*/numElements == 1); 1127 } 1128 1129 //===----------------------------------------------------------------------===// 1130 // DenseFPElementsAttr 1131 //===----------------------------------------------------------------------===// 1132 1133 template <typename Fn, typename Attr> 1134 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1135 Type newElementType, 1136 llvm::SmallVectorImpl<char> &data) { 1137 size_t bitWidth = getDenseElementBitWidth(newElementType); 1138 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1139 1140 ShapedType newArrayType; 1141 if (inType.isa<RankedTensorType>()) 1142 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1143 else if (inType.isa<UnrankedTensorType>()) 1144 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1145 else if (inType.isa<VectorType>()) 1146 newArrayType = VectorType::get(inType.getShape(), newElementType); 1147 else 1148 assert(newArrayType && "Unhandled tensor type"); 1149 1150 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1151 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1152 1153 // Functor used to process a single element value of the attribute. 1154 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1155 auto newInt = mapping(value); 1156 assert(newInt.getBitWidth() == bitWidth); 1157 writeBits(data.data(), index * storageBitWidth, newInt); 1158 }; 1159 1160 // Check for the splat case. 1161 if (attr.isSplat()) { 1162 processElt(*attr.begin(), /*index=*/0); 1163 return newArrayType; 1164 } 1165 1166 // Otherwise, process all of the element values. 1167 uint64_t elementIdx = 0; 1168 for (auto value : attr) 1169 processElt(value, elementIdx++); 1170 return newArrayType; 1171 } 1172 1173 DenseElementsAttr DenseFPElementsAttr::mapValues( 1174 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1175 llvm::SmallVector<char, 8> elementData; 1176 auto newArrayType = 1177 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1178 1179 return getRaw(newArrayType, elementData, isSplat()); 1180 } 1181 1182 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1183 bool DenseFPElementsAttr::classof(Attribute attr) { 1184 return attr.isa<DenseElementsAttr>() && 1185 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1186 } 1187 1188 //===----------------------------------------------------------------------===// 1189 // DenseIntElementsAttr 1190 //===----------------------------------------------------------------------===// 1191 1192 DenseElementsAttr DenseIntElementsAttr::mapValues( 1193 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1194 llvm::SmallVector<char, 8> elementData; 1195 auto newArrayType = 1196 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1197 1198 return getRaw(newArrayType, elementData, isSplat()); 1199 } 1200 1201 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1202 bool DenseIntElementsAttr::classof(Attribute attr) { 1203 return attr.isa<DenseElementsAttr>() && 1204 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1205 } 1206 1207 //===----------------------------------------------------------------------===// 1208 // OpaqueElementsAttr 1209 //===----------------------------------------------------------------------===// 1210 1211 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1212 StringRef bytes) { 1213 assert(TensorType::isValidElementType(type.getElementType()) && 1214 "Input element type should be a valid tensor element type"); 1215 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 1216 dialect, bytes); 1217 } 1218 1219 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1220 1221 /// Return the value at the given index. If index does not refer to a valid 1222 /// element, then a null attribute is returned. 1223 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1224 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1225 if (Dialect *dialect = getDialect()) 1226 return dialect->extractElementHook(*this, index); 1227 return Attribute(); 1228 } 1229 1230 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1231 1232 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1233 if (auto *d = getDialect()) 1234 return d->decodeHook(*this, result); 1235 return true; 1236 } 1237 1238 //===----------------------------------------------------------------------===// 1239 // SparseElementsAttr 1240 //===----------------------------------------------------------------------===// 1241 1242 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1243 DenseElementsAttr indices, 1244 DenseElementsAttr values) { 1245 assert(indices.getType().getElementType().isInteger(64) && 1246 "expected sparse indices to be 64-bit integer values"); 1247 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1248 "type must be ranked tensor or vector"); 1249 assert(type.hasStaticShape() && "type must have static shape"); 1250 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 1251 indices.cast<DenseIntElementsAttr>(), values); 1252 } 1253 1254 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1255 return getImpl()->indices; 1256 } 1257 1258 DenseElementsAttr SparseElementsAttr::getValues() const { 1259 return getImpl()->values; 1260 } 1261 1262 /// Return the value of the element at the given index. 1263 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1264 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1265 auto type = getType(); 1266 1267 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1268 // as a 1-D index array. 1269 auto sparseIndices = getIndices(); 1270 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1271 1272 // Check to see if the indices are a splat. 1273 if (sparseIndices.isSplat()) { 1274 // If the index is also not a splat of the index value, we know that the 1275 // value is zero. 1276 auto splatIndex = *sparseIndexValues.begin(); 1277 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1278 return getZeroAttr(); 1279 1280 // If the indices are a splat, we also expect the values to be a splat. 1281 assert(getValues().isSplat() && "expected splat values"); 1282 return getValues().getSplatValue(); 1283 } 1284 1285 // Build a mapping between known indices and the offset of the stored element. 1286 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1287 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1288 size_t rank = type.getRank(); 1289 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1290 mappedIndices.try_emplace( 1291 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1292 1293 // Look for the provided index key within the mapped indices. If the provided 1294 // index is not found, then return a zero attribute. 1295 auto it = mappedIndices.find(index); 1296 if (it == mappedIndices.end()) 1297 return getZeroAttr(); 1298 1299 // Otherwise, return the held sparse value element. 1300 return getValues().getValue(it->second); 1301 } 1302 1303 /// Get a zero APFloat for the given sparse attribute. 1304 APFloat SparseElementsAttr::getZeroAPFloat() const { 1305 auto eltType = getType().getElementType().cast<FloatType>(); 1306 return APFloat(eltType.getFloatSemantics()); 1307 } 1308 1309 /// Get a zero APInt for the given sparse attribute. 1310 APInt SparseElementsAttr::getZeroAPInt() const { 1311 auto eltType = getType().getElementType().cast<IntegerType>(); 1312 return APInt::getNullValue(eltType.getWidth()); 1313 } 1314 1315 /// Get a zero attribute for the given attribute type. 1316 Attribute SparseElementsAttr::getZeroAttr() const { 1317 auto eltType = getType().getElementType(); 1318 1319 // Handle floating point elements. 1320 if (eltType.isa<FloatType>()) 1321 return FloatAttr::get(eltType, 0); 1322 1323 // Otherwise, this is an integer. 1324 auto intEltTy = eltType.cast<IntegerType>(); 1325 if (intEltTy.getWidth() == 1) 1326 return BoolAttr::get(false, eltType.getContext()); 1327 return IntegerAttr::get(eltType, 0); 1328 } 1329 1330 /// Flatten, and return, all of the sparse indices in this attribute in 1331 /// row-major order. 1332 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1333 std::vector<ptrdiff_t> flatSparseIndices; 1334 1335 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1336 // as a 1-D index array. 1337 auto sparseIndices = getIndices(); 1338 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1339 if (sparseIndices.isSplat()) { 1340 SmallVector<uint64_t, 8> indices(getType().getRank(), 1341 *sparseIndexValues.begin()); 1342 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1343 return flatSparseIndices; 1344 } 1345 1346 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1347 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1348 size_t rank = getType().getRank(); 1349 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1350 flatSparseIndices.push_back(getFlattenedIndex( 1351 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1352 return flatSparseIndices; 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // MutableDictionaryAttr 1357 //===----------------------------------------------------------------------===// 1358 1359 MutableDictionaryAttr::MutableDictionaryAttr( 1360 ArrayRef<NamedAttribute> attributes) { 1361 setAttrs(attributes); 1362 } 1363 1364 /// Return the underlying dictionary attribute. 1365 DictionaryAttr 1366 MutableDictionaryAttr::getDictionary(MLIRContext *context) const { 1367 // Construct empty DictionaryAttr if needed. 1368 if (!attrs) 1369 return DictionaryAttr::get({}, context); 1370 return attrs; 1371 } 1372 1373 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { 1374 return attrs ? attrs.getValue() : llvm::None; 1375 } 1376 1377 /// Replace the held attributes with ones provided in 'newAttrs'. 1378 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { 1379 // Don't create an attribute list if there are no attributes. 1380 if (attributes.empty()) 1381 attrs = nullptr; 1382 else 1383 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1384 } 1385 1386 /// Return the specified attribute if present, null otherwise. 1387 Attribute MutableDictionaryAttr::get(StringRef name) const { 1388 return attrs ? attrs.get(name) : nullptr; 1389 } 1390 1391 /// Return the specified attribute if present, null otherwise. 1392 Attribute MutableDictionaryAttr::get(Identifier name) const { 1393 return attrs ? attrs.get(name) : nullptr; 1394 } 1395 1396 /// Return the specified named attribute if present, None otherwise. 1397 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1398 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1399 } 1400 Optional<NamedAttribute> 1401 MutableDictionaryAttr::getNamed(Identifier name) const { 1402 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1403 } 1404 1405 /// If the an attribute exists with the specified name, change it to the new 1406 /// value. Otherwise, add a new attribute with the specified name/value. 1407 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1408 assert(value && "attributes may never be null"); 1409 1410 // Look for an existing value for the given name, and set it in-place. 1411 ArrayRef<NamedAttribute> values = getAttrs(); 1412 const auto *it = llvm::find_if( 1413 values, [name](NamedAttribute attr) { return attr.first == name; }); 1414 if (it != values.end()) { 1415 // Bail out early if the value is the same as what we already have. 1416 if (it->second == value) 1417 return; 1418 1419 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1420 newAttrs[it - values.begin()].second = value; 1421 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1422 return; 1423 } 1424 1425 // Otherwise, insert the new attribute into its sorted position. 1426 it = llvm::lower_bound(values, name); 1427 SmallVector<NamedAttribute, 8> newAttrs; 1428 newAttrs.reserve(values.size() + 1); 1429 newAttrs.append(values.begin(), it); 1430 newAttrs.push_back({name, value}); 1431 newAttrs.append(it, values.end()); 1432 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1433 } 1434 1435 /// Remove the attribute with the specified name if it exists. The return 1436 /// value indicates whether the attribute was present or not. 1437 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1438 auto origAttrs = getAttrs(); 1439 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1440 if (origAttrs[i].first == name) { 1441 // Handle the simple case of removing the only attribute in the list. 1442 if (e == 1) { 1443 attrs = nullptr; 1444 return RemoveResult::Removed; 1445 } 1446 1447 SmallVector<NamedAttribute, 8> newAttrs; 1448 newAttrs.reserve(origAttrs.size() - 1); 1449 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1450 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1451 attrs = DictionaryAttr::getWithSorted(newAttrs, 1452 newAttrs[0].second.getContext()); 1453 return RemoveResult::Removed; 1454 } 1455 } 1456 return RemoveResult::NotFound; 1457 } 1458 1459 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { 1460 return strcmp(lhs.first.data(), rhs.first.data()) < 0; 1461 } 1462 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { 1463 // This is correct even when attr.first.data()[name.size()] is not a zero 1464 // string terminator, because we only care about a less than comparison. 1465 // This can't use memcmp, because it doesn't guarantee that it will stop 1466 // reading both buffers if one is shorter than the other, even if there is 1467 // a difference. 1468 return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; 1469 } 1470