1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/Twine.h" 20 #include "llvm/Support/Endian.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 //===----------------------------------------------------------------------===// 26 // AttributeStorage 27 //===----------------------------------------------------------------------===// 28 29 AttributeStorage::AttributeStorage(Type type) 30 : type(type.getAsOpaquePointer()) {} 31 AttributeStorage::AttributeStorage() : type(nullptr) {} 32 33 Type AttributeStorage::getType() const { 34 return Type::getFromOpaquePointer(type); 35 } 36 void AttributeStorage::setType(Type newType) { 37 type = newType.getAsOpaquePointer(); 38 } 39 40 //===----------------------------------------------------------------------===// 41 // Attribute 42 //===----------------------------------------------------------------------===// 43 44 /// Return the type of this attribute. 45 Type Attribute::getType() const { return impl->getType(); } 46 47 /// Return the context this attribute belongs to. 48 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 49 50 /// Get the dialect this attribute is registered to. 
51 Dialect &Attribute::getDialect() const { 52 return impl->getAbstractAttribute().getDialect(); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // AffineMapAttr 57 //===----------------------------------------------------------------------===// 58 59 AffineMapAttr AffineMapAttr::get(AffineMap value) { 60 return Base::get(value.getContext(), value); 61 } 62 63 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 64 65 //===----------------------------------------------------------------------===// 66 // ArrayAttr 67 //===----------------------------------------------------------------------===// 68 69 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 70 return Base::get(context, value); 71 } 72 73 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 74 75 Attribute ArrayAttr::operator[](unsigned idx) const { 76 assert(idx < size() && "index out of bounds"); 77 return getValue()[idx]; 78 } 79 80 //===----------------------------------------------------------------------===// 81 // DictionaryAttr 82 //===----------------------------------------------------------------------===// 83 84 /// Helper function that does either an in place sort or sorts from source array 85 /// into destination. If inPlace then storage is both the source and the 86 /// destination, else value is the source and storage destination. Returns 87 /// whether source was sorted. 88 template <bool inPlace> 89 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 90 SmallVectorImpl<NamedAttribute> &storage) { 91 // Specialize for the common case. 92 switch (value.size()) { 93 case 0: 94 // Zero already sorted. 95 break; 96 case 1: 97 // One already sorted but may need to be copied. 
98 if (!inPlace) 99 storage.assign({value[0]}); 100 break; 101 case 2: { 102 assert(value[0].first != value[1].first && 103 "DictionaryAttr element names must be unique"); 104 bool isSorted = value[0] < value[1]; 105 if (inPlace) { 106 if (!isSorted) 107 std::swap(storage[0], storage[1]); 108 } else if (isSorted) { 109 storage.assign({value[0], value[1]}); 110 } else { 111 storage.assign({value[1], value[0]}); 112 } 113 return !isSorted; 114 } 115 default: 116 if (!inPlace) 117 storage.assign(value.begin(), value.end()); 118 // Check to see they are sorted already. 119 bool isSorted = llvm::is_sorted(value); 120 if (!isSorted) { 121 // If not, do a general sort. 122 llvm::array_pod_sort(storage.begin(), storage.end()); 123 value = storage; 124 } 125 126 // Ensure that the attribute elements are unique. 127 assert(std::adjacent_find(value.begin(), value.end(), 128 [](NamedAttribute l, NamedAttribute r) { 129 return l.first == r.first; 130 }) == value.end() && 131 "DictionaryAttr element names must be unique"); 132 return !isSorted; 133 } 134 return false; 135 } 136 137 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 138 SmallVectorImpl<NamedAttribute> &storage) { 139 return dictionaryAttrSort</*inPlace=*/false>(value, storage); 140 } 141 142 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 143 return dictionaryAttrSort</*inPlace=*/true>(array, array); 144 } 145 146 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 147 MLIRContext *context) { 148 if (value.empty()) 149 return DictionaryAttr::getEmpty(context); 150 assert(llvm::all_of(value, 151 [](const NamedAttribute &attr) { return attr.second; }) && 152 "value cannot have null entries"); 153 154 // We need to sort the element list to canonicalize it. 
155 SmallVector<NamedAttribute, 8> storage; 156 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 157 value = storage; 158 159 return Base::get(context, value); 160 } 161 /// Construct a dictionary with an array of values that is known to already be 162 /// sorted by name and uniqued. 163 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 164 MLIRContext *context) { 165 if (value.empty()) 166 return DictionaryAttr::getEmpty(context); 167 // Ensure that the attribute elements are unique and sorted. 168 assert(llvm::is_sorted(value, 169 [](NamedAttribute l, NamedAttribute r) { 170 return l.first.strref() < r.first.strref(); 171 }) && 172 "expected attribute values to be sorted"); 173 assert(std::adjacent_find(value.begin(), value.end(), 174 [](NamedAttribute l, NamedAttribute r) { 175 return l.first == r.first; 176 }) == value.end() && 177 "DictionaryAttr element names must be unique"); 178 return Base::get(context, value); 179 } 180 181 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 182 return getImpl()->getElements(); 183 } 184 185 /// Return the specified attribute if present, null otherwise. 186 Attribute DictionaryAttr::get(StringRef name) const { 187 Optional<NamedAttribute> attr = getNamed(name); 188 return attr ? attr->second : nullptr; 189 } 190 Attribute DictionaryAttr::get(Identifier name) const { 191 Optional<NamedAttribute> attr = getNamed(name); 192 return attr ? attr->second : nullptr; 193 } 194 195 /// Return the specified named attribute if present, None otherwise. 196 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 197 ArrayRef<NamedAttribute> values = getValue(); 198 const auto *it = llvm::lower_bound(values, name); 199 return it != values.end() && it->first == name ? 
*it 200 : Optional<NamedAttribute>(); 201 } 202 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { 203 for (auto elt : getValue()) 204 if (elt.first == name) 205 return elt; 206 return llvm::None; 207 } 208 209 DictionaryAttr::iterator DictionaryAttr::begin() const { 210 return getValue().begin(); 211 } 212 DictionaryAttr::iterator DictionaryAttr::end() const { 213 return getValue().end(); 214 } 215 size_t DictionaryAttr::size() const { return getValue().size(); } 216 217 //===----------------------------------------------------------------------===// 218 // FloatAttr 219 //===----------------------------------------------------------------------===// 220 221 FloatAttr FloatAttr::get(Type type, double value) { 222 return Base::get(type.getContext(), type, value); 223 } 224 225 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 226 return Base::getChecked(loc, type, value); 227 } 228 229 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 230 return Base::get(type.getContext(), type, value); 231 } 232 233 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 234 return Base::getChecked(loc, type, value); 235 } 236 237 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 238 239 double FloatAttr::getValueAsDouble() const { 240 return getValueAsDouble(getValue()); 241 } 242 double FloatAttr::getValueAsDouble(APFloat value) { 243 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 244 bool losesInfo = false; 245 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 246 &losesInfo); 247 } 248 return value.convertToDouble(); 249 } 250 251 /// Verify construction invariants. 
252 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 253 if (!type.isa<FloatType>()) 254 return emitError(loc, "expected floating point type"); 255 return success(); 256 } 257 258 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 259 double value) { 260 return verifyFloatTypeInvariants(loc, type); 261 } 262 263 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 264 const APFloat &value) { 265 // Verify that the type is correct. 266 if (failed(verifyFloatTypeInvariants(loc, type))) 267 return failure(); 268 269 // Verify that the type semantics match that of the value. 270 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 271 return emitError( 272 loc, "FloatAttr type doesn't match the type implied by its value"); 273 } 274 return success(); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // SymbolRefAttr 279 //===----------------------------------------------------------------------===// 280 281 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 282 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>(); 283 } 284 285 SymbolRefAttr SymbolRefAttr::get(StringRef value, 286 ArrayRef<FlatSymbolRefAttr> nestedReferences, 287 MLIRContext *ctx) { 288 return Base::get(ctx, value, nestedReferences); 289 } 290 291 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 292 293 StringRef SymbolRefAttr::getLeafReference() const { 294 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 295 return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getValue(); 296 } 297 298 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 299 return getImpl()->getNestedRefs(); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // IntegerAttr 304 //===----------------------------------------------------------------------===// 305 306 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 307 if (type.isSignlessInteger(1)) 308 return BoolAttr::get(value.getBoolValue(), type.getContext()); 309 return Base::get(type.getContext(), type, value); 310 } 311 312 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 313 // This uses 64 bit APInts by default for index type. 314 if (type.isIndex()) 315 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 316 317 auto intType = type.cast<IntegerType>(); 318 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 319 } 320 321 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 322 323 int64_t IntegerAttr::getInt() const { 324 assert((getImpl()->getType().isIndex() || 325 getImpl()->getType().isSignlessInteger()) && 326 "must be signless integer"); 327 return getValue().getSExtValue(); 328 } 329 330 int64_t IntegerAttr::getSInt() const { 331 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 332 return getValue().getSExtValue(); 333 } 334 335 uint64_t IntegerAttr::getUInt() const { 336 assert(getImpl()->getType().isUnsignedInteger() && 337 "must be unsigned integer"); 338 return getValue().getZExtValue(); 339 } 340 341 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 342 if (type.isa<IntegerType, IndexType>()) 343 return success(); 344 return emitError(loc, "expected integer or index type"); 345 } 346 347 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 348 int64_t value) { 349 return verifyIntegerTypeInvariants(loc, type); 350 } 
351 352 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 353 const APInt &value) { 354 if (failed(verifyIntegerTypeInvariants(loc, type))) 355 return failure(); 356 if (auto integerType = type.dyn_cast<IntegerType>()) 357 if (integerType.getWidth() != value.getBitWidth()) 358 return emitError(loc, "integer type bit width (") 359 << integerType.getWidth() << ") doesn't match value bit width (" 360 << value.getBitWidth() << ")"; 361 return success(); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // BoolAttr 366 367 bool BoolAttr::getValue() const { 368 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl); 369 return storage->getValue().getBoolValue(); 370 } 371 372 bool BoolAttr::classof(Attribute attr) { 373 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); 374 return intAttr && intAttr.getType().isSignlessInteger(1); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // IntegerSetAttr 379 //===----------------------------------------------------------------------===// 380 381 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 382 return Base::get(value.getConstraint(0).getContext(), value); 383 } 384 385 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 386 387 //===----------------------------------------------------------------------===// 388 // OpaqueAttr 389 //===----------------------------------------------------------------------===// 390 391 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 392 MLIRContext *context) { 393 return Base::get(context, dialect, attrData, type); 394 } 395 396 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 397 Type type, Location location) { 398 return Base::getChecked(location, dialect, attrData, type); 399 } 400 401 /// Returns the dialect namespace of the opaque attribute. 
402 Identifier OpaqueAttr::getDialectNamespace() const { 403 return getImpl()->dialectNamespace; 404 } 405 406 /// Returns the raw attribute data of the opaque attribute. 407 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 408 409 /// Verify the construction of an opaque attribute. 410 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 411 Identifier dialect, 412 StringRef attrData, 413 Type type) { 414 if (!Dialect::isValidNamespace(dialect.strref())) 415 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 416 return success(); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // StringAttr 421 //===----------------------------------------------------------------------===// 422 423 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 424 return get(bytes, NoneType::get(context)); 425 } 426 427 /// Get an instance of a StringAttr with the given string and Type. 428 StringAttr StringAttr::get(StringRef bytes, Type type) { 429 return Base::get(type.getContext(), bytes, type); 430 } 431 432 StringRef StringAttr::getValue() const { return getImpl()->value; } 433 434 //===----------------------------------------------------------------------===// 435 // TypeAttr 436 //===----------------------------------------------------------------------===// 437 438 TypeAttr TypeAttr::get(Type value) { 439 return Base::get(value.getContext(), value); 440 } 441 442 Type TypeAttr::getValue() const { return getImpl()->value; } 443 444 //===----------------------------------------------------------------------===// 445 // ElementsAttr 446 //===----------------------------------------------------------------------===// 447 448 ShapedType ElementsAttr::getType() const { 449 return Attribute::getType().cast<ShapedType>(); 450 } 451 452 /// Returns the number of elements held by this attribute. 
453 int64_t ElementsAttr::getNumElements() const { 454 return getType().getNumElements(); 455 } 456 457 /// Return the value at the given index. If index does not refer to a valid 458 /// element, then a null attribute is returned. 459 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 460 if (auto denseAttr = dyn_cast<DenseElementsAttr>()) 461 return denseAttr.getValue(index); 462 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>()) 463 return opaqueAttr.getValue(index); 464 return cast<SparseElementsAttr>().getValue(index); 465 } 466 467 /// Return if the given 'index' refers to a valid element in this attribute. 468 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 469 auto type = getType(); 470 471 // Verify that the rank of the indices matches the held type. 472 auto rank = type.getRank(); 473 if (rank != static_cast<int64_t>(index.size())) 474 return false; 475 476 // Verify that all of the indices are within the shape dimensions. 477 auto shape = type.getShape(); 478 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 479 return static_cast<int64_t>(index[i]) < shape[i]; 480 }); 481 } 482 483 ElementsAttr 484 ElementsAttr::mapValues(Type newElementType, 485 function_ref<APInt(const APInt &)> mapping) const { 486 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 487 return intOrFpAttr.mapValues(newElementType, mapping); 488 llvm_unreachable("unsupported ElementsAttr subtype"); 489 } 490 491 ElementsAttr 492 ElementsAttr::mapValues(Type newElementType, 493 function_ref<APInt(const APFloat &)> mapping) const { 494 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>()) 495 return intOrFpAttr.mapValues(newElementType, mapping); 496 llvm_unreachable("unsupported ElementsAttr subtype"); 497 } 498 499 /// Method for support type inquiry through isa, cast and dyn_cast. 
500 bool ElementsAttr::classof(Attribute attr) { 501 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr, 502 OpaqueElementsAttr, SparseElementsAttr>(); 503 } 504 505 /// Returns the 1 dimensional flattened row-major index from the given 506 /// multi-dimensional index. 507 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 508 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 509 auto type = getType(); 510 511 // Reduce the provided multidimensional index into a flattended 1D row-major 512 // index. 513 auto rank = type.getRank(); 514 auto shape = type.getShape(); 515 uint64_t valueIndex = 0; 516 uint64_t dimMultiplier = 1; 517 for (int i = rank - 1; i >= 0; --i) { 518 valueIndex += index[i] * dimMultiplier; 519 dimMultiplier *= shape[i]; 520 } 521 return valueIndex; 522 } 523 524 //===----------------------------------------------------------------------===// 525 // DenseElementsAttr Utilities 526 //===----------------------------------------------------------------------===// 527 528 /// Get the bitwidth of a dense element type within the buffer. 529 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 530 static size_t getDenseElementStorageWidth(size_t origWidth) { 531 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 532 } 533 static size_t getDenseElementStorageWidth(Type elementType) { 534 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); 535 } 536 537 /// Set a bit to a specific value. 538 static void setBit(char *rawData, size_t bitPos, bool value) { 539 if (value) 540 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 541 else 542 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 543 } 544 545 /// Return the value of the specified bit. 
546 static bool getBit(const char *rawData, size_t bitPos) { 547 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 548 } 549 550 /// Get start position of actual data in `value`. Actual data is 551 /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. 552 static char *getAPIntDataPos(APInt &value, size_t bitWidth) { 553 char *dataPos = 554 const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); 555 if (llvm::support::endian::system_endianness() == 556 llvm::support::endianness::big) 557 dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); 558 return dataPos; 559 } 560 561 /// Read APInt `value` from appropriate position. 562 static void readAPInt(APInt &value, size_t bitWidth, char *outData) { 563 char *dataPos = getAPIntDataPos(value, bitWidth); 564 std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); 565 } 566 567 /// Write `inData` to appropriate position of APInt `value`. 568 static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { 569 char *dataPos = getAPIntDataPos(value, bitWidth); 570 std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); 571 } 572 573 /// Writes value to the bit position `bitPos` in array `rawData`. 574 static void writeBits(char *rawData, size_t bitPos, APInt value) { 575 size_t bitWidth = value.getBitWidth(); 576 577 // If the bitwidth is 1 we just toggle the specific bit. 578 if (bitWidth == 1) 579 return setBit(rawData, bitPos, value.isOneValue()); 580 581 // Otherwise, the bit position is guaranteed to be byte aligned. 582 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 583 readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); 584 } 585 586 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 587 /// `rawData`. 588 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 589 // Handle a boolean bit position. 
590 if (bitWidth == 1) 591 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 592 593 // Otherwise, the bit position must be 8-bit aligned. 594 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 595 APInt result(bitWidth, 0); 596 writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); 597 return result; 598 } 599 600 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has 601 /// the same element count as 'type'. 602 template <typename Values> 603 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 604 return (values.size() == 1) || 605 (type.getNumElements() == static_cast<int64_t>(values.size())); 606 } 607 608 //===----------------------------------------------------------------------===// 609 // DenseElementsAttr Iterators 610 //===----------------------------------------------------------------------===// 611 612 //===----------------------------------------------------------------------===// 613 // AttributeElementIterator 614 615 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 616 DenseElementsAttr attr, size_t index) 617 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 618 Attribute, Attribute, Attribute>( 619 attr.getAsOpaquePointer(), index) {} 620 621 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 622 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 623 Type eltTy = owner.getType().getElementType(); 624 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) 625 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 626 if (eltTy.isa<IndexType>()) 627 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 628 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 629 IntElementIterator intIt(owner, index); 630 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 631 return FloatAttr::get(eltTy, *floatIt); 632 } 633 if (owner.isa<DenseStringElementsAttr>()) 
{ 634 ArrayRef<StringRef> vals = owner.getRawStringData(); 635 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 636 } 637 llvm_unreachable("unexpected element type"); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // BoolElementIterator 642 643 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 644 DenseElementsAttr attr, size_t dataIndex) 645 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 646 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 647 648 bool DenseElementsAttr::BoolElementIterator::operator*() const { 649 return getBit(getData(), getDataIndex()); 650 } 651 652 //===----------------------------------------------------------------------===// 653 // IntElementIterator 654 655 DenseElementsAttr::IntElementIterator::IntElementIterator( 656 DenseElementsAttr attr, size_t dataIndex) 657 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 658 attr.getRawData().data(), attr.isSplat(), dataIndex), 659 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 660 661 APInt DenseElementsAttr::IntElementIterator::operator*() const { 662 return readBits(getData(), 663 getDataIndex() * getDenseElementStorageWidth(bitWidth), 664 bitWidth); 665 } 666 667 //===----------------------------------------------------------------------===// 668 // ComplexIntElementIterator 669 670 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 671 DenseElementsAttr attr, size_t dataIndex) 672 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 673 std::complex<APInt>, std::complex<APInt>, 674 std::complex<APInt>>( 675 attr.getRawData().data(), attr.isSplat(), dataIndex) { 676 auto complexType = attr.getType().getElementType().cast<ComplexType>(); 677 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 678 } 679 680 std::complex<APInt> 681 
DenseElementsAttr::ComplexIntElementIterator::operator*() const { 682 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 683 size_t offset = getDataIndex() * storageWidth * 2; 684 return {readBits(getData(), offset, bitWidth), 685 readBits(getData(), offset + storageWidth, bitWidth)}; 686 } 687 688 //===----------------------------------------------------------------------===// 689 // FloatElementIterator 690 691 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 692 const llvm::fltSemantics &smt, IntElementIterator it) 693 : llvm::mapped_iterator<IntElementIterator, 694 std::function<APFloat(const APInt &)>>( 695 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 696 697 //===----------------------------------------------------------------------===// 698 // ComplexFloatElementIterator 699 700 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( 701 const llvm::fltSemantics &smt, ComplexIntElementIterator it) 702 : llvm::mapped_iterator< 703 ComplexIntElementIterator, 704 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( 705 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { 706 return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; 707 }) {} 708 709 //===----------------------------------------------------------------------===// 710 // DenseElementsAttr 711 //===----------------------------------------------------------------------===// 712 713 /// Method for support type inquiry through isa, cast and dyn_cast. 714 bool DenseElementsAttr::classof(Attribute attr) { 715 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); 716 } 717 718 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 719 ArrayRef<Attribute> values) { 720 assert(hasSameElementsOrSplat(type, values)); 721 722 // If the element type is not based on int/float/index, assume it is a string 723 // type. 
724 auto eltType = type.getElementType(); 725 if (!type.getElementType().isIntOrIndexOrFloat()) { 726 SmallVector<StringRef, 8> stringValues; 727 stringValues.reserve(values.size()); 728 for (Attribute attr : values) { 729 assert(attr.isa<StringAttr>() && 730 "expected string value for non integer/index/float element"); 731 stringValues.push_back(attr.cast<StringAttr>().getValue()); 732 } 733 return get(type, stringValues); 734 } 735 736 // Otherwise, get the raw storage width to use for the allocation. 737 size_t bitWidth = getDenseElementBitWidth(eltType); 738 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 739 740 // Compress the attribute values into a character buffer. 741 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 742 values.size()); 743 APInt intVal; 744 for (unsigned i = 0, e = values.size(); i < e; ++i) { 745 assert(eltType == values[i].getType() && 746 "expected attribute value to have element type"); 747 if (eltType.isa<FloatType>()) 748 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 749 else if (eltType.isa<IntegerType>()) 750 intVal = values[i].cast<IntegerAttr>().getValue(); 751 else 752 llvm_unreachable("unexpected element type"); 753 754 assert(intVal.getBitWidth() == bitWidth && 755 "expected value to have same bitwidth as element type"); 756 writeBits(data.data(), i * storageBitWidth, intVal); 757 } 758 return DenseIntOrFPElementsAttr::getRaw(type, data, 759 /*isSplat=*/(values.size() == 1)); 760 } 761 762 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 763 ArrayRef<bool> values) { 764 assert(hasSameElementsOrSplat(type, values)); 765 assert(type.getElementType().isInteger(1)); 766 767 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 768 for (int i = 0, e = values.size(); i != e; ++i) 769 setBit(buff.data(), i, values[i]); 770 return DenseIntOrFPElementsAttr::getRaw(type, buff, 771 /*isSplat=*/(values.size() == 1)); 772 } 773 774 DenseElementsAttr 
DenseElementsAttr::get(ShapedType type, 775 ArrayRef<StringRef> values) { 776 assert(!type.getElementType().isIntOrFloat()); 777 return DenseStringElementsAttr::get(type, values); 778 } 779 780 /// Constructs a dense integer elements attribute from an array of APInt 781 /// values. Each APInt value is expected to have the same bitwidth as the 782 /// element type of 'type'. 783 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 784 ArrayRef<APInt> values) { 785 assert(type.getElementType().isIntOrIndex()); 786 assert(hasSameElementsOrSplat(type, values)); 787 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 788 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 789 /*isSplat=*/(values.size() == 1)); 790 } 791 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 792 ArrayRef<std::complex<APInt>> values) { 793 ComplexType complex = type.getElementType().cast<ComplexType>(); 794 assert(complex.getElementType().isa<IntegerType>()); 795 assert(hasSameElementsOrSplat(type, values)); 796 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 797 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 798 values.size() * 2); 799 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, 800 /*isSplat=*/(values.size() == 1)); 801 } 802 803 // Constructs a dense float elements attribute from an array of APFloat 804 // values. Each APFloat value is expected to have the same bitwidth as the 805 // element type of 'type'. 
806 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 807 ArrayRef<APFloat> values) { 808 assert(type.getElementType().isa<FloatType>()); 809 assert(hasSameElementsOrSplat(type, values)); 810 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 811 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, 812 /*isSplat=*/(values.size() == 1)); 813 } 814 DenseElementsAttr 815 DenseElementsAttr::get(ShapedType type, 816 ArrayRef<std::complex<APFloat>> values) { 817 ComplexType complex = type.getElementType().cast<ComplexType>(); 818 assert(complex.getElementType().isa<FloatType>()); 819 assert(hasSameElementsOrSplat(type, values)); 820 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 821 values.size() * 2); 822 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 823 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, 824 /*isSplat=*/(values.size() == 1)); 825 } 826 827 /// Construct a dense elements attribute from a raw buffer representing the 828 /// data for this attribute. Users should generally not use this methods as 829 /// the expected buffer format may not be a form the user expects. 830 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 831 ArrayRef<char> rawBuffer, 832 bool isSplatBuffer) { 833 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 834 } 835 836 /// Returns true if the given buffer is a valid raw buffer for the given type. 837 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 838 ArrayRef<char> rawBuffer, 839 bool &detectedSplat) { 840 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 841 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 842 843 // Storage width of 1 is special as it is packed by the bit. 844 if (storageWidth == 1) { 845 // Check for a splat, or a buffer equal to the number of elements. 
846 if ((detectedSplat = rawBuffer.size() == 1)) 847 return true; 848 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); 849 } 850 // All other types are 8-bit aligned. 851 if ((detectedSplat = rawBufferWidth == storageWidth)) 852 return true; 853 return rawBufferWidth == (storageWidth * type.getNumElements()); 854 } 855 856 /// Check the information for a C++ data type, check if this type is valid for 857 /// the current attribute. This method is used to verify specific type 858 /// invariants that the templatized 'getValues' method cannot. 859 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 860 bool isSigned) { 861 // Make sure that the data element size is the same as the type element width. 862 if (getDenseElementBitWidth(type) != 863 static_cast<size_t>(dataEltSize * CHAR_BIT)) 864 return false; 865 866 // Check that the element type is either float or integer or index. 867 if (!isInt) 868 return type.isa<FloatType>(); 869 if (type.isIndex()) 870 return true; 871 872 auto intType = type.dyn_cast<IntegerType>(); 873 if (!intType) 874 return false; 875 876 // Make sure signedness semantics is consistent. 877 if (intType.isSignless()) 878 return true; 879 return intType.isSigned() ? isSigned : !isSigned; 880 } 881 882 /// Defaults down the subclass implementation. 883 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 884 ArrayRef<char> data, 885 int64_t dataEltSize, 886 bool isInt, bool isSigned) { 887 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 888 isSigned); 889 } 890 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 891 ArrayRef<char> data, 892 int64_t dataEltSize, 893 bool isInt, 894 bool isSigned) { 895 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 896 isInt, isSigned); 897 } 898 899 /// A method used to verify specific type invariants that the templatized 'get' 900 /// method cannot. 
901 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 902 bool isSigned) const { 903 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, 904 isSigned); 905 } 906 907 /// Check the information for a C++ data type, check if this type is valid for 908 /// the current attribute. 909 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 910 bool isSigned) const { 911 return ::isValidIntOrFloat( 912 getType().getElementType().cast<ComplexType>().getElementType(), 913 dataEltSize / 2, isInt, isSigned); 914 } 915 916 /// Returns true if this attribute corresponds to a splat, i.e. if all element 917 /// values are the same. 918 bool DenseElementsAttr::isSplat() const { 919 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 920 } 921 922 /// Return the held element values as a range of Attributes. 923 auto DenseElementsAttr::getAttributeValues() const 924 -> llvm::iterator_range<AttributeElementIterator> { 925 return {attr_value_begin(), attr_value_end()}; 926 } 927 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 928 return AttributeElementIterator(*this, 0); 929 } 930 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 931 return AttributeElementIterator(*this, getNumElements()); 932 } 933 934 /// Return the held element values as a range of bool. The element type of 935 /// this attribute must be of integer type of bitwidth 1. 936 auto DenseElementsAttr::getBoolValues() const 937 -> llvm::iterator_range<BoolElementIterator> { 938 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 939 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 940 (void)eltType; 941 return {BoolElementIterator(*this, 0), 942 BoolElementIterator(*this, getNumElements())}; 943 } 944 945 /// Return the held element values as a range of APInts. The element type of 946 /// this attribute must be of integer type. 
947 auto DenseElementsAttr::getIntValues() const 948 -> llvm::iterator_range<IntElementIterator> { 949 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 950 return {raw_int_begin(), raw_int_end()}; 951 } 952 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 953 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 954 return raw_int_begin(); 955 } 956 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 957 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 958 return raw_int_end(); 959 } 960 auto DenseElementsAttr::getComplexIntValues() const 961 -> llvm::iterator_range<ComplexIntElementIterator> { 962 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 963 (void)eltTy; 964 assert(eltTy.isa<IntegerType>() && "expected complex integral type"); 965 return {ComplexIntElementIterator(*this, 0), 966 ComplexIntElementIterator(*this, getNumElements())}; 967 } 968 969 /// Return the held element values as a range of APFloat. The element type of 970 /// this attribute must be of float type. 
971 auto DenseElementsAttr::getFloatValues() const 972 -> llvm::iterator_range<FloatElementIterator> { 973 auto elementType = getType().getElementType().cast<FloatType>(); 974 const auto &elementSemantics = elementType.getFloatSemantics(); 975 return {FloatElementIterator(elementSemantics, raw_int_begin()), 976 FloatElementIterator(elementSemantics, raw_int_end())}; 977 } 978 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 979 return getFloatValues().begin(); 980 } 981 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 982 return getFloatValues().end(); 983 } 984 auto DenseElementsAttr::getComplexFloatValues() const 985 -> llvm::iterator_range<ComplexFloatElementIterator> { 986 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); 987 assert(eltTy.isa<FloatType>() && "expected complex float type"); 988 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); 989 return {{semantics, {*this, 0}}, 990 {semantics, {*this, static_cast<size_t>(getNumElements())}}}; 991 } 992 993 /// Return the raw storage data held by this attribute. 994 ArrayRef<char> DenseElementsAttr::getRawData() const { 995 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 996 } 997 998 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 999 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 1000 } 1001 1002 /// Return a new DenseElementsAttr that has the same data as the current 1003 /// attribute, but has been reshaped to 'newType'. The new type must have the 1004 /// same total number of elements as well as element type. 
1005 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1006 ShapedType curType = getType(); 1007 if (curType == newType) 1008 return *this; 1009 1010 (void)curType; 1011 assert(newType.getElementType() == curType.getElementType() && 1012 "expected the same element type"); 1013 assert(newType.getNumElements() == curType.getNumElements() && 1014 "expected the same number of elements"); 1015 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 1016 } 1017 1018 DenseElementsAttr 1019 DenseElementsAttr::mapValues(Type newElementType, 1020 function_ref<APInt(const APInt &)> mapping) const { 1021 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 1022 } 1023 1024 DenseElementsAttr DenseElementsAttr::mapValues( 1025 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1026 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // DenseStringElementsAttr 1031 //===----------------------------------------------------------------------===// 1032 1033 DenseStringElementsAttr 1034 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 1035 return Base::get(type.getContext(), type, values, (values.size() == 1)); 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // DenseIntOrFPElementsAttr 1040 //===----------------------------------------------------------------------===// 1041 1042 /// Utility method to write a range of APInt values to a buffer. 
1043 template <typename APRangeT> 1044 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1045 APRangeT &&values) { 1046 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); 1047 size_t offset = 0; 1048 for (auto it = values.begin(), e = values.end(); it != e; 1049 ++it, offset += storageWidth) { 1050 assert((*it).getBitWidth() <= storageWidth); 1051 writeBits(data.data(), offset, *it); 1052 } 1053 } 1054 1055 /// Constructs a dense elements attribute from an array of raw APFloat values. 1056 /// Each APFloat value is expected to have the same bitwidth as the element 1057 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1058 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1059 size_t storageWidth, 1060 ArrayRef<APFloat> values, 1061 bool isSplat) { 1062 std::vector<char> data; 1063 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1064 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1065 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1066 } 1067 1068 /// Constructs a dense elements attribute from an array of raw APInt values. 1069 /// Each APInt value is expected to have the same bitwidth as the element type 1070 /// of 'type'. 
1071 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1072 size_t storageWidth, 1073 ArrayRef<APInt> values, 1074 bool isSplat) { 1075 std::vector<char> data; 1076 writeAPIntsToBuffer(storageWidth, data, values); 1077 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); 1078 } 1079 1080 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1081 ArrayRef<char> data, 1082 bool isSplat) { 1083 assert((type.isa<RankedTensorType, VectorType>()) && 1084 "type must be ranked tensor or vector"); 1085 assert(type.hasStaticShape() && "type must have static shape"); 1086 return Base::get(type.getContext(), type, data, isSplat); 1087 } 1088 1089 /// Overload of the raw 'get' method that asserts that the given type is of 1090 /// complex type. This method is used to verify type invariants that the 1091 /// templatized 'get' method cannot. 1092 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1093 ArrayRef<char> data, 1094 int64_t dataEltSize, 1095 bool isInt, 1096 bool isSigned) { 1097 assert(::isValidIntOrFloat( 1098 type.getElementType().cast<ComplexType>().getElementType(), 1099 dataEltSize / 2, isInt, isSigned)); 1100 1101 int64_t numElements = data.size() / dataEltSize; 1102 assert(numElements == 1 || numElements == type.getNumElements()); 1103 return getRaw(type, data, /*isSplat=*/numElements == 1); 1104 } 1105 1106 /// Overload of the 'getRaw' method that asserts that the given type is of 1107 /// integer type. This method is used to verify type invariants that the 1108 /// templatized 'get' method cannot. 
1109 DenseElementsAttr 1110 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1111 int64_t dataEltSize, bool isInt, 1112 bool isSigned) { 1113 assert( 1114 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); 1115 1116 int64_t numElements = data.size() / dataEltSize; 1117 assert(numElements == 1 || numElements == type.getNumElements()); 1118 return getRaw(type, data, /*isSplat=*/numElements == 1); 1119 } 1120 1121 //===----------------------------------------------------------------------===// 1122 // DenseFPElementsAttr 1123 //===----------------------------------------------------------------------===// 1124 1125 template <typename Fn, typename Attr> 1126 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1127 Type newElementType, 1128 llvm::SmallVectorImpl<char> &data) { 1129 size_t bitWidth = getDenseElementBitWidth(newElementType); 1130 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1131 1132 ShapedType newArrayType; 1133 if (inType.isa<RankedTensorType>()) 1134 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1135 else if (inType.isa<UnrankedTensorType>()) 1136 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 1137 else if (inType.isa<VectorType>()) 1138 newArrayType = VectorType::get(inType.getShape(), newElementType); 1139 else 1140 assert(newArrayType && "Unhandled tensor type"); 1141 1142 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1143 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 1144 1145 // Functor used to process a single element value of the attribute. 1146 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1147 auto newInt = mapping(value); 1148 assert(newInt.getBitWidth() == bitWidth); 1149 writeBits(data.data(), index * storageBitWidth, newInt); 1150 }; 1151 1152 // Check for the splat case. 
1153 if (attr.isSplat()) { 1154 processElt(*attr.begin(), /*index=*/0); 1155 return newArrayType; 1156 } 1157 1158 // Otherwise, process all of the element values. 1159 uint64_t elementIdx = 0; 1160 for (auto value : attr) 1161 processElt(value, elementIdx++); 1162 return newArrayType; 1163 } 1164 1165 DenseElementsAttr DenseFPElementsAttr::mapValues( 1166 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1167 llvm::SmallVector<char, 8> elementData; 1168 auto newArrayType = 1169 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1170 1171 return getRaw(newArrayType, elementData, isSplat()); 1172 } 1173 1174 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1175 bool DenseFPElementsAttr::classof(Attribute attr) { 1176 return attr.isa<DenseElementsAttr>() && 1177 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 1178 } 1179 1180 //===----------------------------------------------------------------------===// 1181 // DenseIntElementsAttr 1182 //===----------------------------------------------------------------------===// 1183 1184 DenseElementsAttr DenseIntElementsAttr::mapValues( 1185 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1186 llvm::SmallVector<char, 8> elementData; 1187 auto newArrayType = 1188 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1189 1190 return getRaw(newArrayType, elementData, isSplat()); 1191 } 1192 1193 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
bool DenseIntElementsAttr::classof(Attribute attr) {
  return attr.isa<DenseElementsAttr>() &&
         attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
}

//===----------------------------------------------------------------------===//
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//

/// Builds an opaque elements attribute holding `bytes` for `dialect`. The
/// element type of `type` must be a valid tensor element type.
OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
                                           StringRef bytes) {
  assert(TensorType::isValidElementType(type.getElementType()) &&
         "Input element type should be a valid tensor element type");
  return Base::get(type.getContext(), type, dialect, bytes);
}

/// Returns the opaque byte payload.
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }

/// Returns the value at the given index. The payload is opaque and is not
/// decoded here, so this always returns a null attribute. The index is only
/// validated, via assert.
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  return Attribute();
}

/// Returns the dialect that owns the opaque payload (may be null).
Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }

/// Tries to decode the payload into `result` through the owning dialect's
/// DialectDecodeAttributesInterface. Returns true on FAILURE. Failure means
/// there is no dialect, the dialect lacks the interface, or decode() failed.
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
  auto *d = getDialect();
  if (!d)
    return true;
  auto *interface =
      d->getRegisteredInterface<DialectDecodeAttributesInterface>();
  if (!interface)
    return true;
  return failed(interface->decode(*this, result));
}

//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//

/// Builds a sparse elements attribute from 64-bit integer `indices` and the
/// corresponding `values`. 'type' must be a statically shaped ranked tensor
/// or vector.
SparseElementsAttr SparseElementsAttr::get(ShapedType type,
                                           DenseElementsAttr indices,
                                           DenseElementsAttr values) {
  assert(indices.getType().getElementType().isInteger(64) &&
         "expected sparse indices to be 64-bit integer values");
  assert((type.isa<RankedTensorType, VectorType>()) &&
         "type must be ranked tensor or vector");
  assert(type.hasStaticShape() && "type must have static shape");
  return Base::get(type.getContext(), type,
                   indices.cast<DenseIntElementsAttr>(), values);
}

/// Returns the stored index attribute.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  return getImpl()->indices;
}

/// Returns the stored value attribute.
DenseElementsAttr SparseElementsAttr::getValues() const {
  return getImpl()->values;
}

/// Return the value of the element at the given index, or a zero attribute if
/// the index is not one of the stored sparse indices.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // A splat index has every coordinate equal to the splat value. If any
    // coordinate of the requested index differs, the element is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // The indices are walked as `numSparseIndices` consecutive groups of `rank`
  // coordinates (dim 0 of the index attribute counts the groups). Each group
  // is keyed by an ArrayRef that points directly into the index data.
  // NOTE(review): this map is rebuilt on every call, so each lookup costs
  // O(numSparseIndices).
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}

/// Get a zero APFloat for the given sparse attribute.
APFloat SparseElementsAttr::getZeroAPFloat() const {
  auto eltType = getType().getElementType().cast<FloatType>();
  return APFloat(eltType.getFloatSemantics());
}

/// Get a zero APInt for the given sparse attribute.
/// NOTE(review): this casts the element type to IntegerType, so it looks like
/// it would assert for an `index` element type. Confirm whether index-typed
/// sparse attributes can reach this path.
APInt SparseElementsAttr::getZeroAPInt() const {
  auto eltType = getType().getElementType().cast<IntegerType>();
  return APInt::getNullValue(eltType.getWidth());
}

/// Get a zero attribute for the given attribute type. Float element types get
/// a FloatAttr; every other element type is treated as an integer.
Attribute SparseElementsAttr::getZeroAttr() const {
  auto eltType = getType().getElementType();

  // Handle floating point elements.
  if (eltType.isa<FloatType>())
    return FloatAttr::get(eltType, 0);

  // Otherwise, this is an integer.
  // TODO: Handle StringAttr here.
  return IntegerAttr::get(eltType, 0);
}

/// Flatten, and return, all of the sparse indices in this attribute in
/// row-major order. The result has one entry per stored sparse index. A splat
/// index attribute yields a single entry.
std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
  std::vector<ptrdiff_t> flatSparseIndices;

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    // A splat index repeats the same coordinate across every dimension.
    SmallVector<uint64_t, 8> indices(getType().getRank(),
                                     *sparseIndexValues.begin());
    flatSparseIndices.push_back(getFlattenedIndex(indices));
    return flatSparseIndices;
  }

  // Otherwise, reinterpret each index as an ArrayRef when flattening.
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = getType().getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    flatSparseIndices.push_back(getFlattenedIndex(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
  return flatSparseIndices;
}

//===----------------------------------------------------------------------===//
// MutableDictionaryAttr
//===----------------------------------------------------------------------===//

/// Initializes the held dictionary from `attributes` (see setAttrs).
MutableDictionaryAttr::MutableDictionaryAttr(
    ArrayRef<NamedAttribute> attributes) {
  setAttrs(attributes);
}

/// Return the underlying dictionary attribute. An empty DictionaryAttr is
/// created in `context` when nothing is held.
DictionaryAttr
MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
  // Construct empty DictionaryAttr if needed.
  if (!attrs)
    return DictionaryAttr::get({}, context);
  return attrs;
}

/// Returns the held attributes, or an empty list if there are none.
ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
  return attrs ? attrs.getValue() : llvm::None;
}

/// Replace the held attributes with the ones provided in 'attributes'.
void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
  // Don't create an attribute list if there are no attributes.
  if (attributes.empty())
    attrs = nullptr;
  else
    attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
}

/// Return the specified attribute if present, null otherwise.
Attribute MutableDictionaryAttr::get(StringRef name) const {
  return attrs ? attrs.get(name) : nullptr;
}

/// Return the specified attribute if present, null otherwise.
Attribute MutableDictionaryAttr::get(Identifier name) const {
  return attrs ? attrs.get(name) : nullptr;
}

/// Return the specified named attribute if present, None otherwise.
1389 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { 1390 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1391 } 1392 Optional<NamedAttribute> 1393 MutableDictionaryAttr::getNamed(Identifier name) const { 1394 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); 1395 } 1396 1397 /// If the an attribute exists with the specified name, change it to the new 1398 /// value. Otherwise, add a new attribute with the specified name/value. 1399 void MutableDictionaryAttr::set(Identifier name, Attribute value) { 1400 assert(value && "attributes may never be null"); 1401 1402 // Look for an existing value for the given name, and set it in-place. 1403 ArrayRef<NamedAttribute> values = getAttrs(); 1404 const auto *it = llvm::find_if( 1405 values, [name](NamedAttribute attr) { return attr.first == name; }); 1406 if (it != values.end()) { 1407 // Bail out early if the value is the same as what we already have. 1408 if (it->second == value) 1409 return; 1410 1411 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1412 newAttrs[it - values.begin()].second = value; 1413 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1414 return; 1415 } 1416 1417 // Otherwise, insert the new attribute into its sorted position. 1418 it = llvm::lower_bound(values, name); 1419 SmallVector<NamedAttribute, 8> newAttrs; 1420 newAttrs.reserve(values.size() + 1); 1421 newAttrs.append(values.begin(), it); 1422 newAttrs.push_back({name, value}); 1423 newAttrs.append(it, values.end()); 1424 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1425 } 1426 1427 /// Remove the attribute with the specified name if it exists. The return 1428 /// value indicates whether the attribute was present or not. 
1429 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { 1430 auto origAttrs = getAttrs(); 1431 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1432 if (origAttrs[i].first == name) { 1433 // Handle the simple case of removing the only attribute in the list. 1434 if (e == 1) { 1435 attrs = nullptr; 1436 return RemoveResult::Removed; 1437 } 1438 1439 SmallVector<NamedAttribute, 8> newAttrs; 1440 newAttrs.reserve(origAttrs.size() - 1); 1441 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1442 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1443 attrs = DictionaryAttr::getWithSorted(newAttrs, 1444 newAttrs[0].second.getContext()); 1445 return RemoveResult::Removed; 1446 } 1447 } 1448 return RemoveResult::NotFound; 1449 } 1450 1451 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { 1452 return strcmp(lhs.first.data(), rhs.first.data()) < 0; 1453 } 1454 bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { 1455 // This is correct even when attr.first.data()[name.size()] is not a zero 1456 // string terminator, because we only care about a less than comparison. 1457 // This can't use memcmp, because it doesn't guarantee that it will stop 1458 // reading both buffers if one is shorter than the other, even if there is 1459 // a difference. 1460 return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; 1461 } 1462