1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Endian.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 // AttributeStorage 26 //===----------------------------------------------------------------------===// 27 28 AttributeStorage::AttributeStorage(Type type) 29 : type(type.getAsOpaquePointer()) {} 30 AttributeStorage::AttributeStorage() : type(nullptr) {} 31 32 Type AttributeStorage::getType() const { 33 return Type::getFromOpaquePointer(type); 34 } 35 void AttributeStorage::setType(Type newType) { 36 type = newType.getAsOpaquePointer(); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Attribute 41 //===----------------------------------------------------------------------===// 42 43 /// Return the type of this attribute. 44 Type Attribute::getType() const { return impl->getType(); } 45 46 /// Return the context this attribute belongs to. 47 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 48 49 /// Get the dialect this attribute is registered to. 50 Dialect &Attribute::getDialect() const { 51 return impl->getAbstractAttribute().getDialect(); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // AffineMapAttr 56 //===----------------------------------------------------------------------===// 57 58 AffineMapAttr AffineMapAttr::get(AffineMap value) { 59 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 60 } 61 62 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 63 64 //===----------------------------------------------------------------------===// 65 // ArrayAttr 66 //===----------------------------------------------------------------------===// 67 68 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 69 return Base::get(context, StandardAttributes::Array, value); 70 } 71 72 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 73 74 Attribute ArrayAttr::operator[](unsigned idx) const { 75 assert(idx < size() && "index out of bounds"); 76 return getValue()[idx]; 77 } 78 79 //===----------------------------------------------------------------------===// 80 // DictionaryAttr 81 //===----------------------------------------------------------------------===// 82 83 /// Helper function that does either an in place sort or sorts from source array 84 /// into destination. If inPlace then storage is both the source and the 85 /// destination, else value is the source and storage destination. Returns 86 /// whether source was sorted. 87 template <bool inPlace> 88 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 89 SmallVectorImpl<NamedAttribute> &storage) { 90 // Specialize for the common case. 91 switch (value.size()) { 92 case 0: 93 // Zero already sorted. 94 break; 95 case 1: 96 // One already sorted but may need to be copied. 97 if (!inPlace) 98 storage.assign({value[0]}); 99 break; 100 case 2: { 101 assert(value[0].first != value[1].first && 102 "DictionaryAttr element names must be unique"); 103 bool isSorted = value[0] < value[1]; 104 if (inPlace) { 105 if (!isSorted) 106 std::swap(storage[0], storage[1]); 107 } else if (isSorted) { 108 storage.assign({value[0], value[1]}); 109 } else { 110 storage.assign({value[1], value[0]}); 111 } 112 return !isSorted; 113 } 114 default: 115 if (!inPlace) 116 storage.assign(value.begin(), value.end()); 117 // Check to see they are sorted already. 118 bool isSorted = llvm::is_sorted(value); 119 if (!isSorted) { 120 // If not, do a general sort. 121 llvm::array_pod_sort(storage.begin(), storage.end()); 122 value = storage; 123 } 124 125 // Ensure that the attribute elements are unique. 126 assert(std::adjacent_find(value.begin(), value.end(), 127 [](NamedAttribute l, NamedAttribute r) { 128 return l.first == r.first; 129 }) == value.end() && 130 "DictionaryAttr element names must be unique"); 131 return !isSorted; 132 } 133 return false; 134 } 135 136 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 137 SmallVectorImpl<NamedAttribute> &storage) { 138 return dictionaryAttrSort</*inPlace=*/false>(value, storage); 139 } 140 141 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 142 return dictionaryAttrSort</*inPlace=*/true>(array, array); 143 } 144 145 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 146 MLIRContext *context) { 147 if (value.empty()) 148 return DictionaryAttr::getEmpty(context); 149 assert(llvm::all_of(value, 150 [](const NamedAttribute &attr) { return attr.second; }) && 151 "value cannot have null entries"); 152 153 // We need to sort the element list to canonicalize it. 154 SmallVector<NamedAttribute, 8> storage; 155 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 156 value = storage; 157 158 return Base::get(context, StandardAttributes::Dictionary, value); 159 } 160 /// Construct a dictionary with an array of values that is known to already be 161 /// sorted by name and uniqued. 162 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 163 MLIRContext *context) { 164 if (value.empty()) 165 return DictionaryAttr::getEmpty(context); 166 // Ensure that the attribute elements are unique and sorted. 167 assert(llvm::is_sorted(value, 168 [](NamedAttribute l, NamedAttribute r) { 169 return l.first.strref() < r.first.strref(); 170 }) && 171 "expected attribute values to be sorted"); 172 assert(std::adjacent_find(value.begin(), value.end(), 173 [](NamedAttribute l, NamedAttribute r) { 174 return l.first == r.first; 175 }) == value.end() && 176 "DictionaryAttr element names must be unique"); 177 return Base::get(context, StandardAttributes::Dictionary, value); 178 } 179 180 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 181 return getImpl()->getElements(); 182 } 183 184 /// Return the specified attribute if present, null otherwise. 185 Attribute DictionaryAttr::get(StringRef name) const { 186 Optional<NamedAttribute> attr = getNamed(name); 187 return attr ? attr->second : nullptr; 188 } 189 Attribute DictionaryAttr::get(Identifier name) const { 190 Optional<NamedAttribute> attr = getNamed(name); 191 return attr ? attr->second : nullptr; 192 } 193 194 /// Return the specified named attribute if present, None otherwise. 195 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 196 ArrayRef<NamedAttribute> values = getValue(); 197 const auto *it = llvm::lower_bound(values, name); 198 return it != values.end() && it->first == name ? *it 199 : Optional<NamedAttribute>(); 200 } 201 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 202 for (auto elt : getValue()) 203 if (elt.first == name) 204 return elt; 205 return llvm::None; 206 } 207 208 DictionaryAttr::iterator DictionaryAttr::begin() const { 209 return getValue().begin(); 210 } 211 DictionaryAttr::iterator DictionaryAttr::end() const { 212 return getValue().end(); 213 } 214 size_t DictionaryAttr::size() const { return getValue().size(); } 215 216 //===----------------------------------------------------------------------===// 217 // FloatAttr 218 //===----------------------------------------------------------------------===// 219 220 FloatAttr FloatAttr::get(Type type, double value) { 221 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 222 } 223 224 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 225 return Base::getChecked(loc, StandardAttributes::Float, type, value); 226 } 227 228 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 229 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 230 } 231 232 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 233 return Base::getChecked(loc, StandardAttributes::Float, type, value); 234 } 235 236 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 237 238 double FloatAttr::getValueAsDouble() const { 239 return getValueAsDouble(getValue()); 240 } 241 double FloatAttr::getValueAsDouble(APFloat value) { 242 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 243 bool losesInfo = false; 244 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 245 &losesInfo); 246 } 247 return value.convertToDouble(); 248 } 249 250 /// Verify construction invariants. 251 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 252 if (!type.isa<FloatType>()) 253 return emitError(loc, "expected floating point type"); 254 return success(); 255 } 256 257 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 258 double value) { 259 return verifyFloatTypeInvariants(loc, type); 260 } 261 262 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 263 const APFloat &value) { 264 // Verify that the type is correct. 265 if (failed(verifyFloatTypeInvariants(loc, type))) 266 return failure(); 267 268 // Verify that the type semantics match that of the value. 269 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 270 return emitError( 271 loc, "FloatAttr type doesn't match the type implied by its value"); 272 } 273 return success(); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // SymbolRefAttr 278 //===----------------------------------------------------------------------===// 279 280 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 281 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 282 .cast<FlatSymbolRefAttr>(); 283 } 284 285 SymbolRefAttr SymbolRefAttr::get(StringRef value, 286 ArrayRef<FlatSymbolRefAttr> nestedReferences, 287 MLIRContext *ctx) { 288 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 289 } 290 291 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 292 293 StringRef SymbolRefAttr::getLeafReference() const { 294 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 295 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 296 } 297 298 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 299 return getImpl()->getNestedRefs(); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // IntegerAttr 304 //===----------------------------------------------------------------------===// 305 306 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 307 if (type.isSignlessInteger(1)) 308 return BoolAttr::get(value.getBoolValue(), type.getContext()); 309 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 310 } 311 312 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 313 // This uses 64 bit APInts by default for index type. 314 if (type.isIndex()) 315 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 316 317 auto intType = type.cast<IntegerType>(); 318 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 319 } 320 321 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 322 323 int64_t IntegerAttr::getInt() const { 324 assert((getImpl()->getType().isIndex() || 325 getImpl()->getType().isSignlessInteger()) && 326 "must be signless integer"); 327 return getValue().getSExtValue(); 328 } 329 330 int64_t IntegerAttr::getSInt() const { 331 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 332 return getValue().getSExtValue(); 333 } 334 335 uint64_t IntegerAttr::getUInt() const { 336 assert(getImpl()->getType().isUnsignedInteger() && 337 "must be unsigned integer"); 338 return getValue().getZExtValue(); 339 } 340 341 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 342 if (type.isa<IntegerType, IndexType>()) 343 return success(); 344 return emitError(loc, "expected integer or index type"); 345 } 346 347 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 348 int64_t value) { 349 return verifyIntegerTypeInvariants(loc, type); 350 } 351 352 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 353 const APInt &value) { 354 if (failed(verifyIntegerTypeInvariants(loc, type))) 355 return failure(); 356 if (auto integerType = type.dyn_cast<IntegerType>()) 357 if (integerType.getWidth() != value.getBitWidth()) 358 return emitError(loc, "integer type bit width (") 359 << integerType.getWidth() << ") doesn't match value bit width (" 360 << value.getBitWidth() << ")"; 361 return success(); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // BoolAttr 366 367 bool BoolAttr::getValue() const { 368 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 369 return storage->getValue().getBoolValue(); 370 } 371 372 bool BoolAttr::classof(Attribute attr) { 373 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 374 return intAttr && intAttr.getType().isSignlessInteger(1); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // IntegerSetAttr 379 //===----------------------------------------------------------------------===// 380 381 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 382 return Base::get(value.getConstraint(0).getContext(), 383 StandardAttributes::IntegerSet, value); 384 } 385 386 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 387 388 //===----------------------------------------------------------------------===// 389 // OpaqueAttr 390 //===----------------------------------------------------------------------===// 391 392 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 393 MLIRContext *context) { 394 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 395 type); 396 } 397 398 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 399 Type type, Location location) { 400 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 401 attrData, type); 402 } 403 404 /// Returns the dialect namespace of the opaque attribute. 405 Identifier OpaqueAttr::getDialectNamespace() const { 406 return getImpl()->dialectNamespace; 407 } 408 409 /// Returns the raw attribute data of the opaque attribute. 410 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 411 412 /// Verify the construction of an opaque attribute. 413 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 414 Identifier dialect, 415 StringRef attrData, 416 Type type) { 417 if (!Dialect::isValidNamespace(dialect.strref())) 418 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 419 return success(); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // StringAttr 424 //===----------------------------------------------------------------------===// 425 426 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 427 return get(bytes, NoneType::get(context)); 428 } 429 430 /// Get an instance of a StringAttr with the given string and Type. 431 StringAttr StringAttr::get(StringRef bytes, Type type) { 432 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 433 } 434 435 StringRef StringAttr::getValue() const { return getImpl()->value; } 436 437 //===----------------------------------------------------------------------===// 438 // TypeAttr 439 //===----------------------------------------------------------------------===// 440 441 TypeAttr TypeAttr::get(Type value) { 442 return Base::get(value.getContext(), StandardAttributes::Type, value); 443 } 444 445 Type TypeAttr::getValue() const { return getImpl()->value; } 446 447 //===----------------------------------------------------------------------===// 448 // ElementsAttr 449 //===----------------------------------------------------------------------===// 450 451 ShapedType ElementsAttr::getType() const { 452 return Attribute::getType().cast<ShapedType>(); 453 } 454 455 /// Returns the number of elements held by this attribute. 456 int64_t ElementsAttr::getNumElements() const { 457 return getType().getNumElements(); 458 } 459 460 /// Return the value at the given index. If index does not refer to a valid 461 /// element, then a null attribute is returned. 462 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 463 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 464 return denseAttr.getValue(index); 465 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 466 return opaqueAttr.getValue(index); 467 return cast<SparseElementsAttr>().getValue(index); 468 } 469 470 /// Return if the given 'index' refers to a valid element in this attribute. 471 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 472 auto type = getType(); 473 474 // Verify that the rank of the indices matches the held type. 475 auto rank = type.getRank(); 476 if (rank != static_cast<int64_t>(index.size())) 477 return false; 478 479 // Verify that all of the indices are within the shape dimensions. 480 auto shape = type.getShape(); 481 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 482 return static_cast<int64_t>(index[i]) < shape[i]; 483 }); 484 } 485 486 ElementsAttr 487 ElementsAttr::mapValues(Type newElementType, 488 function_ref<APInt(const APInt &)> mapping) const { 489 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 490 return intOrFpAttr.mapValues(newElementType, mapping); 491 llvm_unreachable("unsupported ElementsAttr subtype"); 492 } 493 494 ElementsAttr 495 ElementsAttr::mapValues(Type newElementType, 496 function_ref<APInt(const APFloat &)> mapping) const { 497 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 498 return intOrFpAttr.mapValues(newElementType, mapping); 499 llvm_unreachable("unsupported ElementsAttr subtype"); 500 } 501 502 /// Method for support type inquiry through isa, cast and dyn_cast. 503 bool ElementsAttr::classof(Attribute attr) { 504 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 505 OpaqueElementsAttr, SparseElementsAttr>(); 506 } 507 508 /// Returns the 1 dimensional flattened row-major index from the given 509 /// multi-dimensional index. 510 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 511 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 512 auto type = getType(); 513 514 // Reduce the provided multidimensional index into a flattended 1D row-major 515 // index. 516 auto rank = type.getRank(); 517 auto shape = type.getShape(); 518 uint64_t valueIndex = 0; 519 uint64_t dimMultiplier = 1; 520 for (int i = rank - 1; i >= 0; --i) { 521 valueIndex += index[i] * dimMultiplier; 522 dimMultiplier *= shape[i]; 523 } 524 return valueIndex; 525 } 526 527 //===----------------------------------------------------------------------===// 528 // DenseElementAttr Utilities 529 //===----------------------------------------------------------------------===// 530 531 /// Get the bitwidth of a dense element type within the buffer. 532 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 533 static size_t getDenseElementStorageWidth(size_t origWidth) { 534 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 535 } 536 static size_t getDenseElementStorageWidth(Type elementType) { 537 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 538 } 539 540 /// Set a bit to a specific value. 541 static void setBit(char *rawData, size_t bitPos, bool value) { 542 if (value) 543 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 544 else 545 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 546 } 547 548 /// Return the value of the specified bit. 549 static bool getBit(const char *rawData, size_t bitPos) { 550 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 551 } 552 553 /// Get start position of actual data in `value`. Actual data is 554 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. 555 static char *getAPIntDataPos(APInt &value, size_t bitWidth) { 556 char *dataPos = 557 const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); 558 if (llvm::support::endian::system_endianness() == 559 llvm::support::endianness::big) 560 dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); 561 return dataPos; 562 } 563 564 /// Read APInt `value` from appropriate position. 565 static void readAPInt(APInt &value, size_t bitWidth, char *outData) { 566 char *dataPos = getAPIntDataPos(value, bitWidth); 567 std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); 568 } 569 570 /// Write `inData` to appropriate position of APInt `value`. 571 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { 572 char *dataPos = getAPIntDataPos(value, bitWidth); 573 std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); 574 } 575 576 /// Writes value to the bit position `bitPos` in array `rawData`. 577 static void writeBits(char *rawData, size_t bitPos, APInt value) { 578 size_t bitWidth = value.getBitWidth(); 579 580 // If the bitwidth is 1 we just toggle the specific bit. 581 if (bitWidth == 1) 582 return setBit(rawData, bitPos, value.isOneValue()); 583 584 // Otherwise, the bit position is guaranteed to be byte aligned. 585 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 586 readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); 587 } 588 589 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 590 /// `rawData`. 591 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 592 // Handle a boolean bit position. 593 if (bitWidth == 1) 594 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 595 596 // Otherwise, the bit position must be 8-bit aligned. 597 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 598 APInt result(bitWidth, 0); 599 writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); 600 return result; 601 } 602 603 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 604 /// same element count as 'type'. 605 template <typename Values> 606 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 607 return (values.size() == 1) || 608 (type.getNumElements() == static_cast<int64_t>(values.size())); 609 } 610 611 //===----------------------------------------------------------------------===// 612 // DenseElementAttr Iterators 613 //===----------------------------------------------------------------------===// 614 615 //===----------------------------------------------------------------------===// 616 // AttributeElementIterator 617 618 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 619 DenseElementsAttr attr, size_t index) 620 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 621 Attribute, Attribute, Attribute>( 622 attr.getAsOpaquePointer(), index) {} 623 624 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 625 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 626 Type eltTy = owner.getType().getElementType(); 627 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 628 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 629 if (eltTy.isa<IndexType>()) 630 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 631 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 632 IntElementIterator intIt(owner, index); 633 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 634 return FloatAttr::get(eltTy, *floatIt); 635 } 636 if (owner.isa<DenseStringElementsAttr>()) { 637 ArrayRef<StringRef> vals = owner.getRawStringData(); 638 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 639 } 640 llvm_unreachable("unexpected element type"); 641 } 642 643 //===----------------------------------------------------------------------===// 644 // BoolElementIterator 645 646 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 647 DenseElementsAttr attr, size_t dataIndex) 648 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 649 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 650 651 bool DenseElementsAttr::BoolElementIterator::operator*() const { 652 return getBit(getData(), getDataIndex()); 653 } 654 655 //===----------------------------------------------------------------------===// 656 // IntElementIterator 657 658 DenseElementsAttr::IntElementIterator::IntElementIterator( 659 DenseElementsAttr attr, size_t dataIndex) 660 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 661 attr.getRawData().data(), attr.isSplat(), dataIndex), 662 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 663 664 APInt DenseElementsAttr::IntElementIterator::operator*() const { 665 return readBits(getData(), 666 getDataIndex() * getDenseElementStorageWidth(bitWidth), 667 bitWidth); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // ComplexIntElementIterator 672 673 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 674 DenseElementsAttr attr, size_t dataIndex) 675 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 676 std::complex<APInt>, std::complex<APInt>, 677 std::complex<APInt>>( 678 attr.getRawData().data(), attr.isSplat(), dataIndex) { 679 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 680 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 681 } 682 683 std::complex<APInt> 684 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 685 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 686 size_t offset = getDataIndex() * storageWidth * 2; 687 return {readBits(getData(), offset, bitWidth), 688 readBits(getData(), offset + storageWidth, bitWidth)}; 689 } 690 691 //===----------------------------------------------------------------------===// 692 // FloatElementIterator 693 694 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 695 const llvm::fltSemantics &smt, IntElementIterator it) 696 : llvm::mapped_iterator<IntElementIterator, 697 std::function<APFloat(const APInt &)>>( 698 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 699 700 //===----------------------------------------------------------------------===// 701 // ComplexFloatElementIterator 702 703 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 704 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 705 : llvm::mapped_iterator< 706 ComplexIntElementIterator, 707 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 708 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 709 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 710 }) {} 711 712 //===----------------------------------------------------------------------===// 713 // DenseElementsAttr 714 //===----------------------------------------------------------------------===// 715 716 /// Method for support type inquiry through isa, cast and dyn_cast. 717 bool DenseElementsAttr::classof(Attribute attr) { 718 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 719 } 720 721 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 722 ArrayRef<Attribute> values) { 723 assert(hasSameElementsOrSplat(type, values)); 724 725 // If the element type is not based on int/float/index, assume it is a string 726 // type. 727 auto eltType = type.getElementType(); 728 if (!type.getElementType().isIntOrIndexOrFloat()) { 729 SmallVector<StringRef, 8> stringValues; 730 stringValues.reserve(values.size()); 731 for (Attribute attr : values) { 732 assert(attr.isa<StringAttr>() && 733 "expected string value for non integer/index/float element"); 734 stringValues.push_back(attr.cast<StringAttr>().getValue()); 735 } 736 return get(type, stringValues); 737 } 738 739 // Otherwise, get the raw storage width to use for the allocation. 740 size_t bitWidth = getDenseElementBitWidth(eltType); 741 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 742 743 // Compress the attribute values into a character buffer. 744 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 745 values.size()); 746 APInt intVal; 747 for (unsigned i = 0, e = values.size(); i < e; ++i) { 748 assert(eltType == values[i].getType() && 749 "expected attribute value to have element type"); 750 751 switch (eltType.getKind()) { 752 case StandardTypes::BF16: 753 case StandardTypes::F16: 754 case StandardTypes::F32: 755 case StandardTypes::F64: 756 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 757 break; 758 case StandardTypes::Integer: 759 case StandardTypes::Index: 760 intVal = values[i].cast<IntegerAttr>().getValue(); 761 break; 762 default: 763 llvm_unreachable("unexpected element type"); 764 } 765 assert(intVal.getBitWidth() == bitWidth && 766 "expected value to have same bitwidth as element type"); 767 writeBits(data.data(), i * storageBitWidth, intVal); 768 } 769 return DenseIntOrFPElementsAttr::getRaw(type, data, 770 /*isSplat=*/(values.size() == 1)); 771 } 772 773 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 774 ArrayRef<bool> values) { 775 assert(hasSameElementsOrSplat(type, values)); 776 assert(type.getElementType().isInteger(1)); 777 778 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 779 for (int i = 0, e = values.size(); i != e; ++i) 780 setBit(buff.data(), i, values[i]); 781 return DenseIntOrFPElementsAttr::getRaw(type, buff, 782 /*isSplat=*/(values.size() == 1)); 783 } 784 785 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 786 ArrayRef<StringRef> values) { 787 assert(!type.getElementType().isIntOrFloat()); 788 return DenseStringElementsAttr::get(type, values); 789 } 790 791 /// Constructs a dense integer elements attribute from an array of APInt 792 /// values. Each APInt value is expected to have the same bitwidth as the 793 /// element type of 'type'. 794 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 795 ArrayRef<APInt> values) { 796 assert(type.getElementType().isIntOrIndex()); 797 assert(hasSameElementsOrSplat(type, values)); 798 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 799 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 800 /*isSplat=*/(values.size() == 1)); 801 } 802 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 803 ArrayRef<std::complex<APInt>> values) { 804 ComplexType complex = type.getElementType().cast<ComplexType>(); 805 assert(complex.getElementType().isa<IntegerType>()); 806 assert(hasSameElementsOrSplat(type, values)); 807 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 808 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 809 values.size() * 2); 810 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 811 /*isSplat=*/(values.size() == 1)); 812 } 813 814 // Constructs a dense float elements attribute from an array of APFloat 815 // values. Each APFloat value is expected to have the same bitwidth as the 816 // element type of 'type'. 817 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 818 ArrayRef<APFloat> values) { 819 assert(type.getElementType().isa<FloatType>()); 820 assert(hasSameElementsOrSplat(type, values)); 821 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 822 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 823 /*isSplat=*/(values.size() == 1)); 824 } 825 DenseElementsAttr 826 DenseElementsAttr::get(ShapedType type, 827 ArrayRef<std::complex<APFloat>> values) { 828 ComplexType complex = type.getElementType().cast<ComplexType>(); 829 assert(complex.getElementType().isa<FloatType>()); 830 assert(hasSameElementsOrSplat(type, values)); 831 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 832 values.size() * 2); 833 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 834 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 835 /*isSplat=*/(values.size() == 1)); 836 } 837 838 /// Construct a dense elements attribute from a raw buffer representing the 839 /// data for this attribute. Users should generally not use this methods as 840 /// the expected buffer format may not be a form the user expects. 841 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 842 ArrayRef<char> rawBuffer, 843 bool isSplatBuffer) { 844 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 845 } 846 847 /// Returns true if the given buffer is a valid raw buffer for the given type. 848 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 849 ArrayRef<char> rawBuffer, 850 bool &detectedSplat) { 851 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 852 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 853 854 // Storage width of 1 is special as it is packed by the bit. 855 if (storageWidth == 1) { 856 // Check for a splat, or a buffer equal to the number of elements. 857 if ((detectedSplat = rawBuffer.size() == 1)) 858 return true; 859 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 860 } 861 // All other types are 8-bit aligned. 862 if ((detectedSplat = rawBufferWidth == storageWidth)) 863 return true; 864 return rawBufferWidth == (storageWidth * type.getNumElements()); 865 } 866 867 /// Check the information for a C++ data type, check if this type is valid for 868 /// the current attribute. This method is used to verify specific type 869 /// invariants that the templatized 'getValues' method cannot. 870 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 871 bool isSigned) { 872 // Make sure that the data element size is the same as the type element width. 873 if (getDenseElementBitWidth(type) != 874 static_cast<size_t>(dataEltSize * CHAR_BIT)) 875 return false; 876 877 // Check that the element type is either float or integer or index. 878 if (!isInt) 879 return type.isa<FloatType>(); 880 if (type.isIndex()) 881 return true; 882 883 auto intType = type.dyn_cast<IntegerType>(); 884 if (!intType) 885 return false; 886 887 // Make sure signedness semantics is consistent. 888 if (intType.isSignless()) 889 return true; 890 return intType.isSigned() ? isSigned : !isSigned; 891 } 892 893 /// Defaults down the subclass implementation. 894 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 895 ArrayRef<char> data, 896 int64_t dataEltSize, 897 bool isInt, bool isSigned) { 898 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 899 isSigned); 900 } 901 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 902 ArrayRef<char> data, 903 int64_t dataEltSize, 904 bool isInt, 905 bool isSigned) { 906 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 907 isInt, isSigned); 908 } 909 910 /// A method used to verify specific type invariants that the templatized 'get' 911 /// method cannot. 912 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 913 bool isSigned) const { 914 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 915 isSigned); 916 } 917 918 /// Check the information for a C++ data type, check if this type is valid for 919 /// the current attribute. 920 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 921 bool isSigned) const { 922 return ::isValidIntOrFloat( 923 getType().getElementType().cast<ComplexType>().getElementType(), 924 dataEltSize / 2, isInt, isSigned); 925 } 926 927 /// Returns if this attribute corresponds to a splat, i.e. if all element 928 /// values are the same. 929 bool DenseElementsAttr::isSplat() const { 930 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 931 } 932 933 /// Return the held element values as a range of Attributes. 934 auto DenseElementsAttr::getAttributeValues() const 935 -> llvm::iterator_range<AttributeElementIterator> { 936 return {attr_value_begin(), attr_value_end()}; 937 } 938 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 939 return AttributeElementIterator(*this, 0); 940 } 941 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 942 return AttributeElementIterator(*this, getNumElements()); 943 } 944 945 /// Return the held element values as a range of bool. The element type of 946 /// this attribute must be of integer type of bitwidth 1. 947 auto DenseElementsAttr::getBoolValues() const 948 -> llvm::iterator_range<BoolElementIterator> { 949 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 950 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 951 (void)eltType; 952 return {BoolElementIterator(*this, 0), 953 BoolElementIterator(*this, getNumElements())}; 954 } 955 956 /// Return the held element values as a range of APInts. The element type of 957 /// this attribute must be of integer type. 958 auto DenseElementsAttr::getIntValues() const 959 -> llvm::iterator_range<IntElementIterator> { 960 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 961 return {raw_int_begin(), raw_int_end()}; 962 } 963 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 964 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 965 return raw_int_begin(); 966 } 967 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 968 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 969 return raw_int_end(); 970 } 971 auto DenseElementsAttr::getComplexIntValues() const 972 -> llvm::iterator_range<ComplexIntElementIterator> { 973 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 974 (void)eltTy; 975 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 976 return {ComplexIntElementIterator(*this, 0), 977 ComplexIntElementIterator(*this, getNumElements())}; 978 } 979 980 /// Return the held element values as a range of APFloat. The element type of 981 /// this attribute must be of float type. 982 auto DenseElementsAttr::getFloatValues() const 983 -> llvm::iterator_range<FloatElementIterator> { 984 auto elementType = getType().getElementType().cast<FloatType>(); 985 const auto &elementSemantics = elementType.getFloatSemantics(); 986 return {FloatElementIterator(elementSemantics, raw_int_begin()), 987 FloatElementIterator(elementSemantics, raw_int_end())}; 988 } 989 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 990 return getFloatValues().begin(); 991 } 992 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 993 return getFloatValues().end(); 994 } 995 auto DenseElementsAttr::getComplexFloatValues() const 996 -> llvm::iterator_range<ComplexFloatElementIterator> { 997 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 998 assert(eltTy.isa<FloatType>() && "expected complex float type"); 999 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 1000 return {{semantics, {*this, 0}}, 1001 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 1002 } 1003 1004 /// Return the raw storage data held by this attribute. 1005 ArrayRef<char> DenseElementsAttr::getRawData() const { 1006 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 1007 } 1008 1009 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1010 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1011 } 1012 1013 /// Return a new DenseElementsAttr that has the same data as the current 1014 /// attribute, but has been reshaped to 'newType'. The new type must have the 1015 /// same total number of elements as well as element type. 1016 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1017 ShapedType curType = getType(); 1018 if (curType == newType) 1019 return *this; 1020 1021 (void)curType; 1022 assert(newType.getElementType() == curType.getElementType() && 1023 "expected the same element type"); 1024 assert(newType.getNumElements() == curType.getNumElements() && 1025 "expected the same number of elements"); 1026 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1027 } 1028 1029 DenseElementsAttr 1030 DenseElementsAttr::mapValues(Type newElementType, 1031 function_ref<APInt(const APInt &)> mapping) const { 1032 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1033 } 1034 1035 DenseElementsAttr DenseElementsAttr::mapValues( 1036 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1037 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1038 } 1039 1040 //===----------------------------------------------------------------------===// 1041 // DenseStringElementsAttr 1042 //===----------------------------------------------------------------------===// 1043 1044 DenseStringElementsAttr 1045 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1046 return Base::get(type.getContext(), StandardAttributes::DenseStringElements, 1047 type, values, (values.size() == 1)); 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // DenseIntOrFPElementsAttr 1052 //===----------------------------------------------------------------------===// 1053 1054 /// Utility method to write a range of APInt values to a buffer. 1055 template <typename APRangeT> 1056 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1057 APRangeT &&values) { 1058 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1059 size_t offset = 0; 1060 for (auto it = values.begin(), e = values.end(); it != e; 1061 ++it, offset += storageWidth) { 1062 assert((*it).getBitWidth() <= storageWidth); 1063 writeBits(data.data(), offset, *it); 1064 } 1065 } 1066 1067 /// Constructs a dense elements attribute from an array of raw APFloat values. 1068 /// Each APFloat value is expected to have the same bitwidth as the element 1069 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1071 size_t storageWidth, 1072 ArrayRef<APFloat> values, 1073 bool isSplat) { 1074 std::vector<char> data; 1075 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1076 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1077 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1078 } 1079 1080 /// Constructs a dense elements attribute from an array of raw APInt values. 1081 /// Each APInt value is expected to have the same bitwidth as the element type 1082 /// of 'type'. 1083 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1084 size_t storageWidth, 1085 ArrayRef<APInt> values, 1086 bool isSplat) { 1087 std::vector<char> data; 1088 writeAPIntsToBuffer(storageWidth, data, values); 1089 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1090 } 1091 1092 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1093 ArrayRef<char> data, 1094 bool isSplat) { 1095 assert((type.isa<RankedTensorType, VectorType>()) && 1096 "type must be ranked tensor or vector"); 1097 assert(type.hasStaticShape() && "type must have static shape"); 1098 return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, 1099 type, data, isSplat); 1100 } 1101 1102 /// Overload of the raw 'get' method that asserts that the given type is of 1103 /// complex type. This method is used to verify type invariants that the 1104 /// templatized 'get' method cannot. 1105 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1106 ArrayRef<char> data, 1107 int64_t dataEltSize, 1108 bool isInt, 1109 bool isSigned) { 1110 assert(::isValidIntOrFloat( 1111 type.getElementType().cast<ComplexType>().getElementType(), 1112 dataEltSize / 2, isInt, isSigned)); 1113 1114 int64_t numElements = data.size() / dataEltSize; 1115 assert(numElements == 1 || numElements == type.getNumElements()); 1116 return getRaw(type, data, /*isSplat=*/numElements == 1); 1117 } 1118 1119 /// Overload of the 'getRaw' method that asserts that the given type is of 1120 /// integer type. This method is used to verify type invariants that the 1121 /// templatized 'get' method cannot. 1122 DenseElementsAttr 1123 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1124 int64_t dataEltSize, bool isInt, 1125 bool isSigned) { 1126 assert( 1127 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1128 1129 int64_t numElements = data.size() / dataEltSize; 1130 assert(numElements == 1 || numElements == type.getNumElements()); 1131 return getRaw(type, data, /*isSplat=*/numElements == 1); 1132 } 1133 1134 //===----------------------------------------------------------------------===// 1135 // DenseFPElementsAttr 1136 //===----------------------------------------------------------------------===// 1137 1138 template <typename Fn, typename Attr> 1139 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1140 Type newElementType, 1141 llvm::SmallVectorImpl<char> &data) { 1142 size_t bitWidth = getDenseElementBitWidth(newElementType); 1143 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1144 1145 ShapedType newArrayType; 1146 if (inType.isa<RankedTensorType>()) 1147 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1148 else if (inType.isa<UnrankedTensorType>()) 1149 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1150 else if (inType.isa<VectorType>()) 1151 newArrayType = VectorType::get(inType.getShape(), newElementType); 1152 else 1153 assert(newArrayType && "Unhandled tensor type"); 1154 1155 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1156 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1157 1158 // Functor used to process a single element value of the attribute. 1159 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1160 auto newInt = mapping(value); 1161 assert(newInt.getBitWidth() == bitWidth); 1162 writeBits(data.data(), index * storageBitWidth, newInt); 1163 }; 1164 1165 // Check for the splat case. 1166 if (attr.isSplat()) { 1167 processElt(*attr.begin(), /*index=*/0); 1168 return newArrayType; 1169 } 1170 1171 // Otherwise, process all of the element values. 1172 uint64_t elementIdx = 0; 1173 for (auto value : attr) 1174 processElt(value, elementIdx++); 1175 return newArrayType; 1176 } 1177 1178 DenseElementsAttr DenseFPElementsAttr::mapValues( 1179 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1180 llvm::SmallVector<char, 8> elementData; 1181 auto newArrayType = 1182 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1183 1184 return getRaw(newArrayType, elementData, isSplat()); 1185 } 1186 1187 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1188 bool DenseFPElementsAttr::classof(Attribute attr) { 1189 return attr.isa<DenseElementsAttr>() && 1190 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // DenseIntElementsAttr 1195 //===----------------------------------------------------------------------===// 1196 1197 DenseElementsAttr DenseIntElementsAttr::mapValues( 1198 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1199 llvm::SmallVector<char, 8> elementData; 1200 auto newArrayType = 1201 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1202 1203 return getRaw(newArrayType, elementData, isSplat()); 1204 } 1205 1206 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1207 bool DenseIntElementsAttr::classof(Attribute attr) { 1208 return attr.isa<DenseElementsAttr>() && 1209 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 1210 } 1211 1212 //===----------------------------------------------------------------------===// 1213 // OpaqueElementsAttr 1214 //===----------------------------------------------------------------------===// 1215 1216 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1217 StringRef bytes) { 1218 assert(TensorType::isValidElementType(type.getElementType()) && 1219 "Input element type should be a valid tensor element type"); 1220 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 1221 dialect, bytes); 1222 } 1223 1224 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1225 1226 /// Return the value at the given index. If index does not refer to a valid 1227 /// element, then a null attribute is returned. 1228 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1229 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1230 if (Dialect *dialect = getDialect()) 1231 return dialect->extractElementHook(*this, index); 1232 return Attribute(); 1233 } 1234 1235 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1236 1237 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1238 if (auto *d = getDialect()) 1239 return d->decodeHook(*this, result); 1240 return true; 1241 } 1242 1243 //===----------------------------------------------------------------------===// 1244 // SparseElementsAttr 1245 //===----------------------------------------------------------------------===// 1246 1247 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1248 DenseElementsAttr indices, 1249 DenseElementsAttr values) { 1250 assert(indices.getType().getElementType().isInteger(64) && 1251 "expected sparse indices to be 64-bit integer values"); 1252 assert((type.isa<RankedTensorType, VectorType>()) && 1253 "type must be ranked tensor or vector"); 1254 assert(type.hasStaticShape() && "type must have static shape"); 1255 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 1256 indices.cast<DenseIntElementsAttr>(), values); 1257 } 1258 1259 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1260 return getImpl()->indices; 1261 } 1262 1263 DenseElementsAttr SparseElementsAttr::getValues() const { 1264 return getImpl()->values; 1265 } 1266 1267 /// Return the value of the element at the given index. 1268 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1269 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1270 auto type = getType(); 1271 1272 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1273 // as a 1-D index array. 1274 auto sparseIndices = getIndices(); 1275 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1276 1277 // Check to see if the indices are a splat. 1278 if (sparseIndices.isSplat()) { 1279 // If the index is also not a splat of the index value, we know that the 1280 // value is zero. 1281 auto splatIndex = *sparseIndexValues.begin(); 1282 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1283 return getZeroAttr(); 1284 1285 // If the indices are a splat, we also expect the values to be a splat. 1286 assert(getValues().isSplat() && "expected splat values"); 1287 return getValues().getSplatValue(); 1288 } 1289 1290 // Build a mapping between known indices and the offset of the stored element. 1291 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1292 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1293 size_t rank = type.getRank(); 1294 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1295 mappedIndices.try_emplace( 1296 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1297 1298 // Look for the provided index key within the mapped indices. If the provided 1299 // index is not found, then return a zero attribute. 1300 auto it = mappedIndices.find(index); 1301 if (it == mappedIndices.end()) 1302 return getZeroAttr(); 1303 1304 // Otherwise, return the held sparse value element. 1305 return getValues().getValue(it->second); 1306 } 1307 1308 /// Get a zero APFloat for the given sparse attribute. 1309 APFloat SparseElementsAttr::getZeroAPFloat() const { 1310 auto eltType = getType().getElementType().cast<FloatType>(); 1311 return APFloat(eltType.getFloatSemantics()); 1312 } 1313 1314 /// Get a zero APInt for the given sparse attribute. 1315 APInt SparseElementsAttr::getZeroAPInt() const { 1316 auto eltType = getType().getElementType().cast<IntegerType>(); 1317 return APInt::getNullValue(eltType.getWidth()); 1318 } 1319 1320 /// Get a zero attribute for the given attribute type. 1321 Attribute SparseElementsAttr::getZeroAttr() const { 1322 auto eltType = getType().getElementType(); 1323 1324 // Handle floating point elements. 1325 if (eltType.isa<FloatType>()) 1326 return FloatAttr::get(eltType, 0); 1327 1328 // Otherwise, this is an integer. 1329 // TODO: Handle StringAttr here. 1330 return IntegerAttr::get(eltType, 0); 1331 } 1332 1333 /// Flatten, and return, all of the sparse indices in this attribute in 1334 /// row-major order. 1335 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1336 std::vector<ptrdiff_t> flatSparseIndices; 1337 1338 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1339 // as a 1-D index array. 1340 auto sparseIndices = getIndices(); 1341 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1342 if (sparseIndices.isSplat()) { 1343 SmallVector<uint64_t, 8> indices(getType().getRank(), 1344 *sparseIndexValues.begin()); 1345 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1346 return flatSparseIndices; 1347 } 1348 1349 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1350 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1351 size_t rank = getType().getRank(); 1352 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1353 flatSparseIndices.push_back(getFlattenedIndex( 1354 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1355 return flatSparseIndices; 1356 } 1357 1358 //===----------------------------------------------------------------------===// 1359 // MutableDictionaryAttr 1360 //===----------------------------------------------------------------------===// 1361 1362 MutableDictionaryAttr::MutableDictionaryAttr( 1363 ArrayRef<NamedAttribute> attributes) { 1364 setAttrs(attributes); 1365 } 1366 1367 /// Return the underlying dictionary attribute. 1368 DictionaryAttr 1369 MutableDictionaryAttr::getDictionary(MLIRContext *context) const { 1370 // Construct empty DictionaryAttr if needed. 1371 if (!attrs) 1372 return DictionaryAttr::get({}, context); 1373 return attrs; 1374 } 1375 1376 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { 1377 return attrs ? attrs.getValue() : llvm::None; 1378 } 1379 1380 /// Replace the held attributes with ones provided in 'newAttrs'. 1381 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { 1382 // Don't create an attribute list if there are no attributes. 1383 if (attributes.empty()) 1384 attrs = nullptr; 1385 else 1386 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1387 } 1388 1389 /// Return the specified attribute if present, null otherwise. 1390 Attribute MutableDictionaryAttr::get(StringRef name) const { 1391 return attrs ? attrs.get(name) : nullptr; 1392 } 1393 1394 /// Return the specified attribute if present, null otherwise. 1395 Attribute MutableDictionaryAttr::get(Identifier name) const { 1396 return attrs ? attrs.get(name) : nullptr; 1397 } 1398 1399 /// Return the specified named attribute if present, None otherwise. 1400 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1401 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1402 } 1403 Optional<NamedAttribute> 1404 MutableDictionaryAttr::getNamed(Identifier name) const { 1405 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1406 } 1407 1408 /// If the an attribute exists with the specified name, change it to the new 1409 /// value. Otherwise, add a new attribute with the specified name/value. 1410 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1411 assert(value && "attributes may never be null"); 1412 1413 // Look for an existing value for the given name, and set it in-place. 1414 ArrayRef<NamedAttribute> values = getAttrs(); 1415 const auto *it = llvm::find_if( 1416 values, [name](NamedAttribute attr) { return attr.first == name; }); 1417 if (it != values.end()) { 1418 // Bail out early if the value is the same as what we already have. 1419 if (it->second == value) 1420 return; 1421 1422 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1423 newAttrs[it - values.begin()].second = value; 1424 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1425 return; 1426 } 1427 1428 // Otherwise, insert the new attribute into its sorted position. 1429 it = llvm::lower_bound(values, name); 1430 SmallVector<NamedAttribute, 8> newAttrs; 1431 newAttrs.reserve(values.size() + 1); 1432 newAttrs.append(values.begin(), it); 1433 newAttrs.push_back({name, value}); 1434 newAttrs.append(it, values.end()); 1435 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1436 } 1437 1438 /// Remove the attribute with the specified name if it exists. The return 1439 /// value indicates whether the attribute was present or not. 1440 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1441 auto origAttrs = getAttrs(); 1442 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1443 if (origAttrs[i].first == name) { 1444 // Handle the simple case of removing the only attribute in the list. 1445 if (e == 1) { 1446 attrs = nullptr; 1447 return RemoveResult::Removed; 1448 } 1449 1450 SmallVector<NamedAttribute, 8> newAttrs; 1451 newAttrs.reserve(origAttrs.size() - 1); 1452 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1453 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1454 attrs = DictionaryAttr::getWithSorted(newAttrs, 1455 newAttrs[0].second.getContext()); 1456 return RemoveResult::Removed; 1457 } 1458 } 1459 return RemoveResult::NotFound; 1460 } 1461 1462 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { 1463 return strcmp(lhs.first.data(), rhs.first.data()) < 0; 1464 } 1465 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { 1466 // This is correct even when attr.first.data()[name.size()] is not a zero 1467 // string terminator, because we only care about a less than comparison. 1468 // This can't use memcmp, because it doesn't guarantee that it will stop 1469 // reading both buffers if one is shorter than the other, even if there is 1470 // a difference. 1471 return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; 1472 } 1473