1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AttributeStorage 26 //===----------------------------------------------------------------------===// 27 28 AttributeStorage::AttributeStorage(Type type) 29 : type(type.getAsOpaquePointer()) {} 30 AttributeStorage::AttributeStorage() : type(nullptr) {} 31 32 Type AttributeStorage::getType() const { 33 return Type::getFromOpaquePointer(type); 34 } 35 void AttributeStorage::setType(Type newType) { 36 type = newType.getAsOpaquePointer(); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Attribute 41 //===----------------------------------------------------------------------===// 42 43 /// Return the type of this attribute. 44 Type Attribute::getType() const { return impl->getType(); } 45 46 /// Return the context this attribute belongs to. 47 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 48 49 /// Get the dialect this attribute is registered to. 50 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 51 52 //===----------------------------------------------------------------------===// 53 // AffineMapAttr 54 //===----------------------------------------------------------------------===// 55 56 AffineMapAttr AffineMapAttr::get(AffineMap value) { 57 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 58 } 59 60 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 61 62 //===----------------------------------------------------------------------===// 63 // ArrayAttr 64 //===----------------------------------------------------------------------===// 65 66 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 67 return Base::get(context, StandardAttributes::Array, value); 68 } 69 70 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 71 72 Attribute ArrayAttr::operator[](unsigned idx) const { 73 assert(idx < size() && "index out of bounds"); 74 return getValue()[idx]; 75 } 76 77 //===----------------------------------------------------------------------===// 78 // DictionaryAttr 79 //===----------------------------------------------------------------------===// 80 81 /// Helper function that does either an in place sort or sorts from source array 82 /// into destination. If inPlace then storage is both the source and the 83 /// destination, else value is the source and storage destination. Returns 84 /// whether source was sorted. 85 template <bool inPlace> 86 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 87 SmallVectorImpl<NamedAttribute> &storage) { 88 // Specialize for the common case. 89 switch (value.size()) { 90 case 0: 91 // Zero already sorted. 92 break; 93 case 1: 94 // One already sorted but may need to be copied. 95 if (!inPlace) 96 storage.assign({value[0]}); 97 break; 98 case 2: { 99 assert(value[0].first != value[1].first && 100 "DictionaryAttr element names must be unique"); 101 bool isSorted = value[0] < value[1]; 102 if (inPlace) { 103 if (!isSorted) 104 std::swap(storage[0], storage[1]); 105 } else if (isSorted) { 106 storage.assign({value[0], value[1]}); 107 } else { 108 storage.assign({value[1], value[0]}); 109 } 110 return !isSorted; 111 } 112 default: 113 if (!inPlace) 114 storage.assign(value.begin(), value.end()); 115 // Check to see they are sorted already. 116 bool isSorted = llvm::is_sorted(value); 117 if (!isSorted) { 118 // If not, do a general sort. 119 llvm::array_pod_sort(storage.begin(), storage.end()); 120 value = storage; 121 } 122 123 // Ensure that the attribute elements are unique. 124 assert(std::adjacent_find(value.begin(), value.end(), 125 [](NamedAttribute l, NamedAttribute r) { 126 return l.first == r.first; 127 }) == value.end() && 128 "DictionaryAttr element names must be unique"); 129 return !isSorted; 130 } 131 return false; 132 } 133 134 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 135 SmallVectorImpl<NamedAttribute> &storage) { 136 return dictionaryAttrSort</*inPlace=*/false>(value, storage); 137 } 138 139 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 140 return dictionaryAttrSort</*inPlace=*/true>(array, array); 141 } 142 143 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 144 MLIRContext *context) { 145 if (value.empty()) 146 return DictionaryAttr::getEmpty(context); 147 assert(llvm::all_of(value, 148 [](const NamedAttribute &attr) { return attr.second; }) && 149 "value cannot have null entries"); 150 151 // We need to sort the element list to canonicalize it. 152 SmallVector<NamedAttribute, 8> storage; 153 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 154 value = storage; 155 156 return Base::get(context, StandardAttributes::Dictionary, value); 157 } 158 /// Construct a dictionary with an array of values that is known to already be 159 /// sorted by name and uniqued. 160 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 161 MLIRContext *context) { 162 if (value.empty()) 163 return DictionaryAttr::getEmpty(context); 164 // Ensure that the attribute elements are unique and sorted. 165 assert(llvm::is_sorted(value, 166 [](NamedAttribute l, NamedAttribute r) { 167 return l.first.strref() < r.first.strref(); 168 }) && 169 "expected attribute values to be sorted"); 170 assert(std::adjacent_find(value.begin(), value.end(), 171 [](NamedAttribute l, NamedAttribute r) { 172 return l.first == r.first; 173 }) == value.end() && 174 "DictionaryAttr element names must be unique"); 175 return Base::get(context, StandardAttributes::Dictionary, value); 176 } 177 178 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 179 return getImpl()->getElements(); 180 } 181 182 /// Return the specified attribute if present, null otherwise. 183 Attribute DictionaryAttr::get(StringRef name) const { 184 Optional<NamedAttribute> attr = getNamed(name); 185 return attr ? attr->second : nullptr; 186 } 187 Attribute DictionaryAttr::get(Identifier name) const { 188 Optional<NamedAttribute> attr = getNamed(name); 189 return attr ? attr->second : nullptr; 190 } 191 192 /// Return the specified named attribute if present, None otherwise. 193 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 194 ArrayRef<NamedAttribute> values = getValue(); 195 const auto *it = llvm::lower_bound(values, name); 196 return it != values.end() && it->first == name ? *it 197 : Optional<NamedAttribute>(); 198 } 199 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 200 for (auto elt : getValue()) 201 if (elt.first == name) 202 return elt; 203 return llvm::None; 204 } 205 206 DictionaryAttr::iterator DictionaryAttr::begin() const { 207 return getValue().begin(); 208 } 209 DictionaryAttr::iterator DictionaryAttr::end() const { 210 return getValue().end(); 211 } 212 size_t DictionaryAttr::size() const { return getValue().size(); } 213 214 //===----------------------------------------------------------------------===// 215 // FloatAttr 216 //===----------------------------------------------------------------------===// 217 218 FloatAttr FloatAttr::get(Type type, double value) { 219 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 220 } 221 222 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 223 return Base::getChecked(loc, StandardAttributes::Float, type, value); 224 } 225 226 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 227 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 228 } 229 230 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 231 return Base::getChecked(loc, StandardAttributes::Float, type, value); 232 } 233 234 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 235 236 double FloatAttr::getValueAsDouble() const { 237 return getValueAsDouble(getValue()); 238 } 239 double FloatAttr::getValueAsDouble(APFloat value) { 240 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 241 bool losesInfo = false; 242 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 243 &losesInfo); 244 } 245 return value.convertToDouble(); 246 } 247 248 /// Verify construction invariants. 249 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 250 if (!type.isa<FloatType>()) 251 return emitError(loc, "expected floating point type"); 252 return success(); 253 } 254 255 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 256 double value) { 257 return verifyFloatTypeInvariants(loc, type); 258 } 259 260 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 261 const APFloat &value) { 262 // Verify that the type is correct. 263 if (failed(verifyFloatTypeInvariants(loc, type))) 264 return failure(); 265 266 // Verify that the type semantics match that of the value. 267 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 268 return emitError( 269 loc, "FloatAttr type doesn't match the type implied by its value"); 270 } 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // SymbolRefAttr 276 //===----------------------------------------------------------------------===// 277 278 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 279 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 280 .cast<FlatSymbolRefAttr>(); 281 } 282 283 SymbolRefAttr SymbolRefAttr::get(StringRef value, 284 ArrayRef<FlatSymbolRefAttr> nestedReferences, 285 MLIRContext *ctx) { 286 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 287 } 288 289 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 290 291 StringRef SymbolRefAttr::getLeafReference() const { 292 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 293 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 294 } 295 296 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 297 return getImpl()->getNestedRefs(); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // IntegerAttr 302 //===----------------------------------------------------------------------===// 303 304 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 305 if (type.isSignlessInteger(1)) 306 return BoolAttr::get(value.getBoolValue(), type.getContext()); 307 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 308 } 309 310 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 311 // This uses 64 bit APInts by default for index type. 312 if (type.isIndex()) 313 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 314 315 auto intType = type.cast<IntegerType>(); 316 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 317 } 318 319 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 320 321 int64_t IntegerAttr::getInt() const { 322 assert((getImpl()->getType().isIndex() || 323 getImpl()->getType().isSignlessInteger()) && 324 "must be signless integer"); 325 return getValue().getSExtValue(); 326 } 327 328 int64_t IntegerAttr::getSInt() const { 329 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 330 return getValue().getSExtValue(); 331 } 332 333 uint64_t IntegerAttr::getUInt() const { 334 assert(getImpl()->getType().isUnsignedInteger() && 335 "must be unsigned integer"); 336 return getValue().getZExtValue(); 337 } 338 339 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 340 if (type.isa<IntegerType>() || type.isa<IndexType>()) 341 return success(); 342 return emitError(loc, "expected integer or index type"); 343 } 344 345 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 346 int64_t value) { 347 return verifyIntegerTypeInvariants(loc, type); 348 } 349 350 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 351 const APInt &value) { 352 if (failed(verifyIntegerTypeInvariants(loc, type))) 353 return failure(); 354 if (auto integerType = type.dyn_cast<IntegerType>()) 355 if (integerType.getWidth() != value.getBitWidth()) 356 return emitError(loc, "integer type bit width (") 357 << integerType.getWidth() << ") doesn't match value bit width (" 358 << value.getBitWidth() << ")"; 359 return success(); 360 } 361 362 //===----------------------------------------------------------------------===// 363 // BoolAttr 364 365 bool BoolAttr::getValue() const { 366 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 367 return storage->getValue().getBoolValue(); 368 } 369 370 bool BoolAttr::classof(Attribute attr) { 371 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 372 return intAttr && intAttr.getType().isSignlessInteger(1); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // IntegerSetAttr 377 //===----------------------------------------------------------------------===// 378 379 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 380 return Base::get(value.getConstraint(0).getContext(), 381 StandardAttributes::IntegerSet, value); 382 } 383 384 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 385 386 //===----------------------------------------------------------------------===// 387 // OpaqueAttr 388 //===----------------------------------------------------------------------===// 389 390 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 391 MLIRContext *context) { 392 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 393 type); 394 } 395 396 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 397 Type type, Location location) { 398 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 399 attrData, type); 400 } 401 402 /// Returns the dialect namespace of the opaque attribute. 403 Identifier OpaqueAttr::getDialectNamespace() const { 404 return getImpl()->dialectNamespace; 405 } 406 407 /// Returns the raw attribute data of the opaque attribute. 408 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 409 410 /// Verify the construction of an opaque attribute. 411 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 412 Identifier dialect, 413 StringRef attrData, 414 Type type) { 415 if (!Dialect::isValidNamespace(dialect.strref())) 416 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 417 return success(); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // StringAttr 422 //===----------------------------------------------------------------------===// 423 424 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 425 return get(bytes, NoneType::get(context)); 426 } 427 428 /// Get an instance of a StringAttr with the given string and Type. 429 StringAttr StringAttr::get(StringRef bytes, Type type) { 430 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 431 } 432 433 StringRef StringAttr::getValue() const { return getImpl()->value; } 434 435 //===----------------------------------------------------------------------===// 436 // TypeAttr 437 //===----------------------------------------------------------------------===// 438 439 TypeAttr TypeAttr::get(Type value) { 440 return Base::get(value.getContext(), StandardAttributes::Type, value); 441 } 442 443 Type TypeAttr::getValue() const { return getImpl()->value; } 444 445 //===----------------------------------------------------------------------===// 446 // ElementsAttr 447 //===----------------------------------------------------------------------===// 448 449 ShapedType ElementsAttr::getType() const { 450 return Attribute::getType().cast<ShapedType>(); 451 } 452 453 /// Returns the number of elements held by this attribute. 454 int64_t ElementsAttr::getNumElements() const { 455 return getType().getNumElements(); 456 } 457 458 /// Return the value at the given index. If index does not refer to a valid 459 /// element, then a null attribute is returned. 460 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 461 switch (getKind()) { 462 case StandardAttributes::DenseIntOrFPElements: 463 return cast<DenseElementsAttr>().getValue(index); 464 case StandardAttributes::OpaqueElements: 465 return cast<OpaqueElementsAttr>().getValue(index); 466 case StandardAttributes::SparseElements: 467 return cast<SparseElementsAttr>().getValue(index); 468 default: 469 llvm_unreachable("unknown ElementsAttr kind"); 470 } 471 } 472 473 /// Return if the given 'index' refers to a valid element in this attribute. 474 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 475 auto type = getType(); 476 477 // Verify that the rank of the indices matches the held type. 478 auto rank = type.getRank(); 479 if (rank != static_cast<int64_t>(index.size())) 480 return false; 481 482 // Verify that all of the indices are within the shape dimensions. 483 auto shape = type.getShape(); 484 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 485 return static_cast<int64_t>(index[i]) < shape[i]; 486 }); 487 } 488 489 ElementsAttr 490 ElementsAttr::mapValues(Type newElementType, 491 function_ref<APInt(const APInt &)> mapping) const { 492 switch (getKind()) { 493 case StandardAttributes::DenseIntOrFPElements: 494 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 495 default: 496 llvm_unreachable("unsupported ElementsAttr subtype"); 497 } 498 } 499 500 ElementsAttr 501 ElementsAttr::mapValues(Type newElementType, 502 function_ref<APInt(const APFloat &)> mapping) const { 503 switch (getKind()) { 504 case StandardAttributes::DenseIntOrFPElements: 505 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 506 default: 507 llvm_unreachable("unsupported ElementsAttr subtype"); 508 } 509 } 510 511 /// Returns the 1 dimensional flattened row-major index from the given 512 /// multi-dimensional index. 513 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 514 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 515 auto type = getType(); 516 517 // Reduce the provided multidimensional index into a flattended 1D row-major 518 // index. 519 auto rank = type.getRank(); 520 auto shape = type.getShape(); 521 uint64_t valueIndex = 0; 522 uint64_t dimMultiplier = 1; 523 for (int i = rank - 1; i >= 0; --i) { 524 valueIndex += index[i] * dimMultiplier; 525 dimMultiplier *= shape[i]; 526 } 527 return valueIndex; 528 } 529 530 //===----------------------------------------------------------------------===// 531 // DenseElementAttr Utilities 532 //===----------------------------------------------------------------------===// 533 534 /// Get the bitwidth of a dense element type within the buffer. 535 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 536 static size_t getDenseElementStorageWidth(size_t origWidth) { 537 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 538 } 539 static size_t getDenseElementStorageWidth(Type elementType) { 540 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 541 } 542 543 /// Set a bit to a specific value. 544 static void setBit(char *rawData, size_t bitPos, bool value) { 545 if (value) 546 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 547 else 548 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 549 } 550 551 /// Return the value of the specified bit. 552 static bool getBit(const char *rawData, size_t bitPos) { 553 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 554 } 555 556 /// Get start position of actual data in `value`. Actual data is 557 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. 558 static char *getAPIntDataPos(APInt &value, size_t bitWidth) { 559 char *dataPos = 560 const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); 561 if (llvm::support::endian::system_endianness() == 562 llvm::support::endianness::big) 563 dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); 564 return dataPos; 565 } 566 567 /// Read APInt `value` from appropriate position. 568 static void readAPInt(APInt &value, size_t bitWidth, char *outData) { 569 char *dataPos = getAPIntDataPos(value, bitWidth); 570 std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); 571 } 572 573 /// Write `inData` to appropriate position of APInt `value`. 574 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { 575 char *dataPos = getAPIntDataPos(value, bitWidth); 576 std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); 577 } 578 579 /// Writes value to the bit position `bitPos` in array `rawData`. 580 static void writeBits(char *rawData, size_t bitPos, APInt value) { 581 size_t bitWidth = value.getBitWidth(); 582 583 // If the bitwidth is 1 we just toggle the specific bit. 584 if (bitWidth == 1) 585 return setBit(rawData, bitPos, value.isOneValue()); 586 587 // Otherwise, the bit position is guaranteed to be byte aligned. 588 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 589 readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); 590 } 591 592 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 593 /// `rawData`. 594 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 595 // Handle a boolean bit position. 596 if (bitWidth == 1) 597 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 598 599 // Otherwise, the bit position must be 8-bit aligned. 600 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 601 APInt result(bitWidth, 0); 602 writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); 603 return result; 604 } 605 606 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 607 /// same element count as 'type'. 608 template <typename Values> 609 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 610 return (values.size() == 1) || 611 (type.getNumElements() == static_cast<int64_t>(values.size())); 612 } 613 614 //===----------------------------------------------------------------------===// 615 // DenseElementAttr Iterators 616 //===----------------------------------------------------------------------===// 617 618 //===----------------------------------------------------------------------===// 619 // AttributeElementIterator 620 621 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 622 DenseElementsAttr attr, size_t index) 623 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 624 Attribute, Attribute, Attribute>( 625 attr.getAsOpaquePointer(), index) {} 626 627 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 628 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 629 Type eltTy = owner.getType().getElementType(); 630 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 631 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 632 if (eltTy.isa<IndexType>()) 633 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 634 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 635 IntElementIterator intIt(owner, index); 636 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 637 return FloatAttr::get(eltTy, *floatIt); 638 } 639 if (owner.isa<DenseStringElementsAttr>()) { 640 ArrayRef<StringRef> vals = owner.getRawStringData(); 641 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 642 } 643 llvm_unreachable("unexpected element type"); 644 } 645 646 //===----------------------------------------------------------------------===// 647 // BoolElementIterator 648 649 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 650 DenseElementsAttr attr, size_t dataIndex) 651 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 652 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 653 654 bool DenseElementsAttr::BoolElementIterator::operator*() const { 655 return getBit(getData(), getDataIndex()); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // IntElementIterator 660 661 DenseElementsAttr::IntElementIterator::IntElementIterator( 662 DenseElementsAttr attr, size_t dataIndex) 663 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 664 attr.getRawData().data(), attr.isSplat(), dataIndex), 665 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 666 667 APInt DenseElementsAttr::IntElementIterator::operator*() const { 668 return readBits(getData(), 669 getDataIndex() * getDenseElementStorageWidth(bitWidth), 670 bitWidth); 671 } 672 673 //===----------------------------------------------------------------------===// 674 // ComplexIntElementIterator 675 676 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 677 DenseElementsAttr attr, size_t dataIndex) 678 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 679 std::complex<APInt>, std::complex<APInt>, 680 std::complex<APInt>>( 681 attr.getRawData().data(), attr.isSplat(), dataIndex) { 682 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 683 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 684 } 685 686 std::complex<APInt> 687 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 688 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 689 size_t offset = getDataIndex() * storageWidth * 2; 690 return {readBits(getData(), offset, bitWidth), 691 readBits(getData(), offset + storageWidth, bitWidth)}; 692 } 693 694 //===----------------------------------------------------------------------===// 695 // FloatElementIterator 696 697 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 698 const llvm::fltSemantics &smt, IntElementIterator it) 699 : llvm::mapped_iterator<IntElementIterator, 700 std::function<APFloat(const APInt &)>>( 701 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 702 703 //===----------------------------------------------------------------------===// 704 // ComplexFloatElementIterator 705 706 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 707 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 708 : llvm::mapped_iterator< 709 ComplexIntElementIterator, 710 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 711 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 712 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 713 }) {} 714 715 //===----------------------------------------------------------------------===// 716 // DenseElementsAttr 717 //===----------------------------------------------------------------------===// 718 719 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 720 ArrayRef<Attribute> values) { 721 assert(hasSameElementsOrSplat(type, values)); 722 723 // If the element type is not based on int/float/index, assume it is a string 724 // type. 725 auto eltType = type.getElementType(); 726 if (!type.getElementType().isIntOrIndexOrFloat()) { 727 SmallVector<StringRef, 8> stringValues; 728 stringValues.reserve(values.size()); 729 for (Attribute attr : values) { 730 assert(attr.isa<StringAttr>() && 731 "expected string value for non integer/index/float element"); 732 stringValues.push_back(attr.cast<StringAttr>().getValue()); 733 } 734 return get(type, stringValues); 735 } 736 737 // Otherwise, get the raw storage width to use for the allocation. 738 size_t bitWidth = getDenseElementBitWidth(eltType); 739 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 740 741 // Compress the attribute values into a character buffer. 742 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 743 values.size()); 744 APInt intVal; 745 for (unsigned i = 0, e = values.size(); i < e; ++i) { 746 assert(eltType == values[i].getType() && 747 "expected attribute value to have element type"); 748 749 switch (eltType.getKind()) { 750 case StandardTypes::BF16: 751 case StandardTypes::F16: 752 case StandardTypes::F32: 753 case StandardTypes::F64: 754 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 755 break; 756 case StandardTypes::Integer: 757 case StandardTypes::Index: 758 intVal = values[i].cast<IntegerAttr>().getValue(); 759 break; 760 default: 761 llvm_unreachable("unexpected element type"); 762 } 763 assert(intVal.getBitWidth() == bitWidth && 764 "expected value to have same bitwidth as element type"); 765 writeBits(data.data(), i * storageBitWidth, intVal); 766 } 767 return DenseIntOrFPElementsAttr::getRaw(type, data, 768 /*isSplat=*/(values.size() == 1)); 769 } 770 771 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 772 ArrayRef<bool> values) { 773 assert(hasSameElementsOrSplat(type, values)); 774 assert(type.getElementType().isInteger(1)); 775 776 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 777 for (int i = 0, e = values.size(); i != e; ++i) 778 setBit(buff.data(), i, values[i]); 779 return DenseIntOrFPElementsAttr::getRaw(type, buff, 780 /*isSplat=*/(values.size() == 1)); 781 } 782 783 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 784 ArrayRef<StringRef> values) { 785 assert(!type.getElementType().isIntOrFloat()); 786 return DenseStringElementsAttr::get(type, values); 787 } 788 789 /// Constructs a dense integer elements attribute from an array of APInt 790 /// values. Each APInt value is expected to have the same bitwidth as the 791 /// element type of 'type'. 792 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 793 ArrayRef<APInt> values) { 794 assert(type.getElementType().isIntOrIndex()); 795 assert(hasSameElementsOrSplat(type, values)); 796 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 797 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 798 /*isSplat=*/(values.size() == 1)); 799 } 800 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 801 ArrayRef<std::complex<APInt>> values) { 802 ComplexType complex = type.getElementType().cast<ComplexType>(); 803 assert(complex.getElementType().isa<IntegerType>()); 804 assert(hasSameElementsOrSplat(type, values)); 805 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 806 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 807 values.size() * 2); 808 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 809 /*isSplat=*/(values.size() == 1)); 810 } 811 812 // Constructs a dense float elements attribute from an array of APFloat 813 // values. Each APFloat value is expected to have the same bitwidth as the 814 // element type of 'type'. 815 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 816 ArrayRef<APFloat> values) { 817 assert(type.getElementType().isa<FloatType>()); 818 assert(hasSameElementsOrSplat(type, values)); 819 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 820 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 821 /*isSplat=*/(values.size() == 1)); 822 } 823 DenseElementsAttr 824 DenseElementsAttr::get(ShapedType type, 825 ArrayRef<std::complex<APFloat>> values) { 826 ComplexType complex = type.getElementType().cast<ComplexType>(); 827 assert(complex.getElementType().isa<FloatType>()); 828 assert(hasSameElementsOrSplat(type, values)); 829 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 830 values.size() * 2); 831 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 832 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 833 /*isSplat=*/(values.size() == 1)); 834 } 835 836 /// Construct a dense elements attribute from a raw buffer representing the 837 /// data for this attribute. Users should generally not use this methods as 838 /// the expected buffer format may not be a form the user expects. 839 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 840 ArrayRef<char> rawBuffer, 841 bool isSplatBuffer) { 842 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 843 } 844 845 /// Returns true if the given buffer is a valid raw buffer for the given type. 846 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 847 ArrayRef<char> rawBuffer, 848 bool &detectedSplat) { 849 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 850 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 851 852 // Storage width of 1 is special as it is packed by the bit. 853 if (storageWidth == 1) { 854 // Check for a splat, or a buffer equal to the number of elements. 855 if ((detectedSplat = rawBuffer.size() == 1)) 856 return true; 857 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 858 } 859 // All other types are 8-bit aligned. 860 if ((detectedSplat = rawBufferWidth == storageWidth)) 861 return true; 862 return rawBufferWidth == (storageWidth * type.getNumElements()); 863 } 864 865 /// Check the information for a C++ data type, check if this type is valid for 866 /// the current attribute. This method is used to verify specific type 867 /// invariants that the templatized 'getValues' method cannot. 868 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 869 bool isSigned) { 870 // Make sure that the data element size is the same as the type element width. 871 if (getDenseElementBitWidth(type) != 872 static_cast<size_t>(dataEltSize * CHAR_BIT)) 873 return false; 874 875 // Check that the element type is either float or integer or index. 876 if (!isInt) 877 return type.isa<FloatType>(); 878 if (type.isIndex()) 879 return true; 880 881 auto intType = type.dyn_cast<IntegerType>(); 882 if (!intType) 883 return false; 884 885 // Make sure signedness semantics is consistent. 886 if (intType.isSignless()) 887 return true; 888 return intType.isSigned() ? isSigned : !isSigned; 889 } 890 891 /// Defaults down the subclass implementation. 892 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 893 ArrayRef<char> data, 894 int64_t dataEltSize, 895 bool isInt, bool isSigned) { 896 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 897 isSigned); 898 } 899 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 900 ArrayRef<char> data, 901 int64_t dataEltSize, 902 bool isInt, 903 bool isSigned) { 904 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 905 isInt, isSigned); 906 } 907 908 /// A method used to verify specific type invariants that the templatized 'get' 909 /// method cannot. 910 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 911 bool isSigned) const { 912 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 913 isSigned); 914 } 915 916 /// Check the information for a C++ data type, check if this type is valid for 917 /// the current attribute. 918 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 919 bool isSigned) const { 920 return ::isValidIntOrFloat( 921 getType().getElementType().cast<ComplexType>().getElementType(), 922 dataEltSize / 2, isInt, isSigned); 923 } 924 925 /// Returns if this attribute corresponds to a splat, i.e. if all element 926 /// values are the same. 927 bool DenseElementsAttr::isSplat() const { 928 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 929 } 930 931 /// Return the held element values as a range of Attributes. 932 auto DenseElementsAttr::getAttributeValues() const 933 -> llvm::iterator_range<AttributeElementIterator> { 934 return {attr_value_begin(), attr_value_end()}; 935 } 936 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 937 return AttributeElementIterator(*this, 0); 938 } 939 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 940 return AttributeElementIterator(*this, getNumElements()); 941 } 942 943 /// Return the held element values as a range of bool. The element type of 944 /// this attribute must be of integer type of bitwidth 1. 945 auto DenseElementsAttr::getBoolValues() const 946 -> llvm::iterator_range<BoolElementIterator> { 947 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 948 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 949 (void)eltType; 950 return {BoolElementIterator(*this, 0), 951 BoolElementIterator(*this, getNumElements())}; 952 } 953 954 /// Return the held element values as a range of APInts. The element type of 955 /// this attribute must be of integer type. 956 auto DenseElementsAttr::getIntValues() const 957 -> llvm::iterator_range<IntElementIterator> { 958 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 959 return {raw_int_begin(), raw_int_end()}; 960 } 961 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 962 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 963 return raw_int_begin(); 964 } 965 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 966 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 967 return raw_int_end(); 968 } 969 auto DenseElementsAttr::getComplexIntValues() const 970 -> llvm::iterator_range<ComplexIntElementIterator> { 971 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 972 (void)eltTy; 973 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 974 return {ComplexIntElementIterator(*this, 0), 975 ComplexIntElementIterator(*this, getNumElements())}; 976 } 977 978 /// Return the held element values as a range of APFloat. The element type of 979 /// this attribute must be of float type. 980 auto DenseElementsAttr::getFloatValues() const 981 -> llvm::iterator_range<FloatElementIterator> { 982 auto elementType = getType().getElementType().cast<FloatType>(); 983 const auto &elementSemantics = elementType.getFloatSemantics(); 984 return {FloatElementIterator(elementSemantics, raw_int_begin()), 985 FloatElementIterator(elementSemantics, raw_int_end())}; 986 } 987 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 988 return getFloatValues().begin(); 989 } 990 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 991 return getFloatValues().end(); 992 } 993 auto DenseElementsAttr::getComplexFloatValues() const 994 -> llvm::iterator_range<ComplexFloatElementIterator> { 995 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 996 assert(eltTy.isa<FloatType>() && "expected complex float type"); 997 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 998 return {{semantics, {*this, 0}}, 999 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1000 } 1001 1002 /// Return the raw storage data held by this attribute. 1003 ArrayRef<char> DenseElementsAttr::getRawData() const { 1004 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1005 } 1006 1007 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1008 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1009 } 1010 1011 /// Return a new DenseElementsAttr that has the same data as the current 1012 /// attribute, but has been reshaped to 'newType'. The new type must have the 1013 /// same total number of elements as well as element type. 1014 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1015 ShapedType curType = getType(); 1016 if (curType == newType) 1017 return *this; 1018 1019 (void)curType; 1020 assert(newType.getElementType() == curType.getElementType() && 1021 "expected the same element type"); 1022 assert(newType.getNumElements() == curType.getNumElements() && 1023 "expected the same number of elements"); 1024 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1025 } 1026 1027 DenseElementsAttr 1028 DenseElementsAttr::mapValues(Type newElementType, 1029 function_ref<APInt(const APInt &)> mapping) const { 1030 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1031 } 1032 1033 DenseElementsAttr DenseElementsAttr::mapValues( 1034 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1035 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // DenseStringElementsAttr 1040 //===----------------------------------------------------------------------===// 1041 1042 DenseStringElementsAttr 1043 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1044 return Base::get(type.getContext(), StandardAttributes::DenseStringElements, 1045 type, values, (values.size() == 1)); 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // DenseIntOrFPElementsAttr 1050 //===----------------------------------------------------------------------===// 1051 1052 /// Utility method to write a range of APInt values to a buffer. 1053 template <typename APRangeT> 1054 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1055 APRangeT &&values) { 1056 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1057 size_t offset = 0; 1058 for (auto it = values.begin(), e = values.end(); it != e; 1059 ++it, offset += storageWidth) { 1060 assert((*it).getBitWidth() <= storageWidth); 1061 writeBits(data.data(), offset, *it); 1062 } 1063 } 1064 1065 /// Constructs a dense elements attribute from an array of raw APFloat values. 1066 /// Each APFloat value is expected to have the same bitwidth as the element 1067 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1068 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1069 size_t storageWidth, 1070 ArrayRef<APFloat> values, 1071 bool isSplat) { 1072 std::vector<char> data; 1073 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1074 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1075 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1076 } 1077 1078 /// Constructs a dense elements attribute from an array of raw APInt values. 1079 /// Each APInt value is expected to have the same bitwidth as the element type 1080 /// of 'type'. 1081 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1082 size_t storageWidth, 1083 ArrayRef<APInt> values, 1084 bool isSplat) { 1085 std::vector<char> data; 1086 writeAPIntsToBuffer(storageWidth, data, values); 1087 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1088 } 1089 1090 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1091 ArrayRef<char> data, 1092 bool isSplat) { 1093 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1094 "type must be ranked tensor or vector"); 1095 assert(type.hasStaticShape() && "type must have static shape"); 1096 return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, 1097 type, data, isSplat); 1098 } 1099 1100 /// Overload of the raw 'get' method that asserts that the given type is of 1101 /// complex type. This method is used to verify type invariants that the 1102 /// templatized 'get' method cannot. 1103 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1104 ArrayRef<char> data, 1105 int64_t dataEltSize, 1106 bool isInt, 1107 bool isSigned) { 1108 assert(::isValidIntOrFloat( 1109 type.getElementType().cast<ComplexType>().getElementType(), 1110 dataEltSize / 2, isInt, isSigned)); 1111 1112 int64_t numElements = data.size() / dataEltSize; 1113 assert(numElements == 1 || numElements == type.getNumElements()); 1114 return getRaw(type, data, /*isSplat=*/numElements == 1); 1115 } 1116 1117 /// Overload of the 'getRaw' method that asserts that the given type is of 1118 /// integer type. This method is used to verify type invariants that the 1119 /// templatized 'get' method cannot. 1120 DenseElementsAttr 1121 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1122 int64_t dataEltSize, bool isInt, 1123 bool isSigned) { 1124 assert( 1125 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1126 1127 int64_t numElements = data.size() / dataEltSize; 1128 assert(numElements == 1 || numElements == type.getNumElements()); 1129 return getRaw(type, data, /*isSplat=*/numElements == 1); 1130 } 1131 1132 //===----------------------------------------------------------------------===// 1133 // DenseFPElementsAttr 1134 //===----------------------------------------------------------------------===// 1135 1136 template <typename Fn, typename Attr> 1137 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1138 Type newElementType, 1139 llvm::SmallVectorImpl<char> &data) { 1140 size_t bitWidth = getDenseElementBitWidth(newElementType); 1141 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1142 1143 ShapedType newArrayType; 1144 if (inType.isa<RankedTensorType>()) 1145 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1146 else if (inType.isa<UnrankedTensorType>()) 1147 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1148 else if (inType.isa<VectorType>()) 1149 newArrayType = VectorType::get(inType.getShape(), newElementType); 1150 else 1151 assert(newArrayType && "Unhandled tensor type"); 1152 1153 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1154 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1155 1156 // Functor used to process a single element value of the attribute. 1157 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1158 auto newInt = mapping(value); 1159 assert(newInt.getBitWidth() == bitWidth); 1160 writeBits(data.data(), index * storageBitWidth, newInt); 1161 }; 1162 1163 // Check for the splat case. 1164 if (attr.isSplat()) { 1165 processElt(*attr.begin(), /*index=*/0); 1166 return newArrayType; 1167 } 1168 1169 // Otherwise, process all of the element values. 1170 uint64_t elementIdx = 0; 1171 for (auto value : attr) 1172 processElt(value, elementIdx++); 1173 return newArrayType; 1174 } 1175 1176 DenseElementsAttr DenseFPElementsAttr::mapValues( 1177 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1178 llvm::SmallVector<char, 8> elementData; 1179 auto newArrayType = 1180 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1181 1182 return getRaw(newArrayType, elementData, isSplat()); 1183 } 1184 1185 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1186 bool DenseFPElementsAttr::classof(Attribute attr) { 1187 return attr.isa<DenseElementsAttr>() && 1188 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // DenseIntElementsAttr 1193 //===----------------------------------------------------------------------===// 1194 1195 DenseElementsAttr DenseIntElementsAttr::mapValues( 1196 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1197 llvm::SmallVector<char, 8> elementData; 1198 auto newArrayType = 1199 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1200 1201 return getRaw(newArrayType, elementData, isSplat()); 1202 } 1203 1204 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1205 bool DenseIntElementsAttr::classof(Attribute attr) { 1206 return attr.isa<DenseElementsAttr>() && 1207 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // OpaqueElementsAttr 1212 //===----------------------------------------------------------------------===// 1213 1214 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1215 StringRef bytes) { 1216 assert(TensorType::isValidElementType(type.getElementType()) && 1217 "Input element type should be a valid tensor element type"); 1218 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 1219 dialect, bytes); 1220 } 1221 1222 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1223 1224 /// Return the value at the given index. If index does not refer to a valid 1225 /// element, then a null attribute is returned. 1226 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1227 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1228 if (Dialect *dialect = getDialect()) 1229 return dialect->extractElementHook(*this, index); 1230 return Attribute(); 1231 } 1232 1233 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1234 1235 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1236 if (auto *d = getDialect()) 1237 return d->decodeHook(*this, result); 1238 return true; 1239 } 1240 1241 //===----------------------------------------------------------------------===// 1242 // SparseElementsAttr 1243 //===----------------------------------------------------------------------===// 1244 1245 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1246 DenseElementsAttr indices, 1247 DenseElementsAttr values) { 1248 assert(indices.getType().getElementType().isInteger(64) && 1249 "expected sparse indices to be 64-bit integer values"); 1250 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1251 "type must be ranked tensor or vector"); 1252 assert(type.hasStaticShape() && "type must have static shape"); 1253 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 1254 indices.cast<DenseIntElementsAttr>(), values); 1255 } 1256 1257 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1258 return getImpl()->indices; 1259 } 1260 1261 DenseElementsAttr SparseElementsAttr::getValues() const { 1262 return getImpl()->values; 1263 } 1264 1265 /// Return the value of the element at the given index. 1266 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1267 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1268 auto type = getType(); 1269 1270 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1271 // as a 1-D index array. 1272 auto sparseIndices = getIndices(); 1273 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1274 1275 // Check to see if the indices are a splat. 1276 if (sparseIndices.isSplat()) { 1277 // If the index is also not a splat of the index value, we know that the 1278 // value is zero. 1279 auto splatIndex = *sparseIndexValues.begin(); 1280 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1281 return getZeroAttr(); 1282 1283 // If the indices are a splat, we also expect the values to be a splat. 1284 assert(getValues().isSplat() && "expected splat values"); 1285 return getValues().getSplatValue(); 1286 } 1287 1288 // Build a mapping between known indices and the offset of the stored element. 1289 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1290 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1291 size_t rank = type.getRank(); 1292 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1293 mappedIndices.try_emplace( 1294 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1295 1296 // Look for the provided index key within the mapped indices. If the provided 1297 // index is not found, then return a zero attribute. 1298 auto it = mappedIndices.find(index); 1299 if (it == mappedIndices.end()) 1300 return getZeroAttr(); 1301 1302 // Otherwise, return the held sparse value element. 1303 return getValues().getValue(it->second); 1304 } 1305 1306 /// Get a zero APFloat for the given sparse attribute. 1307 APFloat SparseElementsAttr::getZeroAPFloat() const { 1308 auto eltType = getType().getElementType().cast<FloatType>(); 1309 return APFloat(eltType.getFloatSemantics()); 1310 } 1311 1312 /// Get a zero APInt for the given sparse attribute. 1313 APInt SparseElementsAttr::getZeroAPInt() const { 1314 auto eltType = getType().getElementType().cast<IntegerType>(); 1315 return APInt::getNullValue(eltType.getWidth()); 1316 } 1317 1318 /// Get a zero attribute for the given attribute type. 1319 Attribute SparseElementsAttr::getZeroAttr() const { 1320 auto eltType = getType().getElementType(); 1321 1322 // Handle floating point elements. 1323 if (eltType.isa<FloatType>()) 1324 return FloatAttr::get(eltType, 0); 1325 1326 // Otherwise, this is an integer. 1327 // TODO: Handle StringAttr here. 1328 return IntegerAttr::get(eltType, 0); 1329 } 1330 1331 /// Flatten, and return, all of the sparse indices in this attribute in 1332 /// row-major order. 1333 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1334 std::vector<ptrdiff_t> flatSparseIndices; 1335 1336 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1337 // as a 1-D index array. 1338 auto sparseIndices = getIndices(); 1339 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1340 if (sparseIndices.isSplat()) { 1341 SmallVector<uint64_t, 8> indices(getType().getRank(), 1342 *sparseIndexValues.begin()); 1343 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1344 return flatSparseIndices; 1345 } 1346 1347 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1348 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1349 size_t rank = getType().getRank(); 1350 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1351 flatSparseIndices.push_back(getFlattenedIndex( 1352 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1353 return flatSparseIndices; 1354 } 1355 1356 //===----------------------------------------------------------------------===// 1357 // MutableDictionaryAttr 1358 //===----------------------------------------------------------------------===// 1359 1360 MutableDictionaryAttr::MutableDictionaryAttr( 1361 ArrayRef<NamedAttribute> attributes) { 1362 setAttrs(attributes); 1363 } 1364 1365 /// Return the underlying dictionary attribute. 1366 DictionaryAttr 1367 MutableDictionaryAttr::getDictionary(MLIRContext *context) const { 1368 // Construct empty DictionaryAttr if needed. 1369 if (!attrs) 1370 return DictionaryAttr::get({}, context); 1371 return attrs; 1372 } 1373 1374 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { 1375 return attrs ? attrs.getValue() : llvm::None; 1376 } 1377 1378 /// Replace the held attributes with ones provided in 'newAttrs'. 1379 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { 1380 // Don't create an attribute list if there are no attributes. 1381 if (attributes.empty()) 1382 attrs = nullptr; 1383 else 1384 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1385 } 1386 1387 /// Return the specified attribute if present, null otherwise. 1388 Attribute MutableDictionaryAttr::get(StringRef name) const { 1389 return attrs ? attrs.get(name) : nullptr; 1390 } 1391 1392 /// Return the specified attribute if present, null otherwise. 1393 Attribute MutableDictionaryAttr::get(Identifier name) const { 1394 return attrs ? attrs.get(name) : nullptr; 1395 } 1396 1397 /// Return the specified named attribute if present, None otherwise. 1398 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1399 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1400 } 1401 Optional<NamedAttribute> 1402 MutableDictionaryAttr::getNamed(Identifier name) const { 1403 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1404 } 1405 1406 /// If the an attribute exists with the specified name, change it to the new 1407 /// value. Otherwise, add a new attribute with the specified name/value. 1408 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1409 assert(value && "attributes may never be null"); 1410 1411 // Look for an existing value for the given name, and set it in-place. 1412 ArrayRef<NamedAttribute> values = getAttrs(); 1413 const auto *it = llvm::find_if( 1414 values, [name](NamedAttribute attr) { return attr.first == name; }); 1415 if (it != values.end()) { 1416 // Bail out early if the value is the same as what we already have. 1417 if (it->second == value) 1418 return; 1419 1420 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1421 newAttrs[it - values.begin()].second = value; 1422 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1423 return; 1424 } 1425 1426 // Otherwise, insert the new attribute into its sorted position. 1427 it = llvm::lower_bound(values, name); 1428 SmallVector<NamedAttribute, 8> newAttrs; 1429 newAttrs.reserve(values.size() + 1); 1430 newAttrs.append(values.begin(), it); 1431 newAttrs.push_back({name, value}); 1432 newAttrs.append(it, values.end()); 1433 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1434 } 1435 1436 /// Remove the attribute with the specified name if it exists. The return 1437 /// value indicates whether the attribute was present or not. 1438 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1439 auto origAttrs = getAttrs(); 1440 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1441 if (origAttrs[i].first == name) { 1442 // Handle the simple case of removing the only attribute in the list. 1443 if (e == 1) { 1444 attrs = nullptr; 1445 return RemoveResult::Removed; 1446 } 1447 1448 SmallVector<NamedAttribute, 8> newAttrs; 1449 newAttrs.reserve(origAttrs.size() - 1); 1450 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1451 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1452 attrs = DictionaryAttr::getWithSorted(newAttrs, 1453 newAttrs[0].second.getContext()); 1454 return RemoveResult::Removed; 1455 } 1456 } 1457 return RemoveResult::NotFound; 1458 } 1459 1460 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { 1461 return strcmp(lhs.first.data(), rhs.first.data()) < 0; 1462 } 1463 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { 1464 // This is correct even when attr.first.data()[name.size()] is not a zero 1465 // string terminator, because we only care about a less than comparison. 1466 // This can't use memcmp, because it doesn't guarantee that it will stop 1467 // reading both buffers if one is shorter than the other, even if there is 1468 // a difference. 1469 return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; 1470 } 1471