1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 Attribute ArrayAttr::operator[](unsigned idx) const { 72 assert(idx < size() && "index out of bounds"); 73 return getValue()[idx]; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // BoolAttr 78 //===----------------------------------------------------------------------===// 79 80 bool BoolAttr::getValue() const { return getImpl()->value; } 81 82 //===----------------------------------------------------------------------===// 83 // DictionaryAttr 84 //===----------------------------------------------------------------------===// 85 86 /// Perform a three-way comparison between the names of the specified 87 /// NamedAttributes. 88 static int compareNamedAttributes(const NamedAttribute *lhs, 89 const NamedAttribute *rhs) { 90 return strcmp(lhs->first.data(), rhs->first.data()); 91 } 92 93 /// Returns if the name of the given attribute precedes that of 'name'. 94 static bool compareNamedAttributeWithName(const NamedAttribute &attr, 95 StringRef name) { 96 // This is correct even when attr.first.data()[name.size()] is not a zero 97 // string terminator, because we only care about a less than comparison. 98 // This can't use memcmp, because it doesn't guarantee that it will stop 99 // reading both buffers if one is shorter than the other, even if there is 100 // a difference. 101 return strncmp(attr.first.data(), name.data(), name.size()) < 0; 102 } 103 104 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 105 MLIRContext *context) { 106 assert(llvm::all_of(value, 107 [](const NamedAttribute &attr) { return attr.second; }) && 108 "value cannot have null entries"); 109 110 // We need to sort the element list to canonicalize it, but we also don't want 111 // to do a ton of work in the super common case where the element list is 112 // already sorted. 113 SmallVector<NamedAttribute, 8> storage; 114 switch (value.size()) { 115 case 0: 116 break; 117 case 1: 118 // A single element is already sorted. 119 break; 120 case 2: 121 assert(value[0].first != value[1].first && 122 "DictionaryAttr element names must be unique"); 123 124 // Don't invoke a general sort for two element case. 125 if (compareNamedAttributes(&value[0], &value[1]) > 0) { 126 storage.push_back(value[1]); 127 storage.push_back(value[0]); 128 value = storage; 129 } 130 break; 131 default: 132 // Check to see they are sorted already. 133 bool isSorted = true; 134 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 135 if (compareNamedAttributes(&value[i], &value[i + 1]) > 0) { 136 isSorted = false; 137 break; 138 } 139 } 140 // If not, do a general sort. 141 if (!isSorted) { 142 storage.append(value.begin(), value.end()); 143 llvm::array_pod_sort(storage.begin(), storage.end(), 144 compareNamedAttributes); 145 value = storage; 146 } 147 148 // Ensure that the attribute elements are unique. 149 assert(std::adjacent_find(value.begin(), value.end(), 150 [](NamedAttribute l, NamedAttribute r) { 151 return l.first == r.first; 152 }) == value.end() && 153 "DictionaryAttr element names must be unique"); 154 } 155 156 return Base::get(context, StandardAttributes::Dictionary, value); 157 } 158 159 /// Construct a dictionary with an array of values that is known to already be 160 /// sorted by name and uniqued. 161 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, 162 MLIRContext *context) { 163 // Ensure that the attribute elements are unique and sorted. 164 assert(llvm::is_sorted(value, 165 [](NamedAttribute l, NamedAttribute r) { 166 return l.first.strref() < r.first.strref(); 167 }) && 168 "expected attribute values to be sorted"); 169 assert(std::adjacent_find(value.begin(), value.end(), 170 [](NamedAttribute l, NamedAttribute r) { 171 return l.first == r.first; 172 }) == value.end() && 173 "DictionaryAttr element names must be unique"); 174 return Base::get(context, StandardAttributes::Dictionary, value); 175 } 176 177 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 178 return getImpl()->getElements(); 179 } 180 181 /// Return the specified attribute if present, null otherwise. 182 Attribute DictionaryAttr::get(StringRef name) const { 183 ArrayRef<NamedAttribute> values = getValue(); 184 auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); 185 return it != values.end() && it->first == name ? it->second : Attribute(); 186 } 187 Attribute DictionaryAttr::get(Identifier name) const { 188 for (auto elt : getValue()) 189 if (elt.first == name) 190 return elt.second; 191 return nullptr; 192 } 193 194 DictionaryAttr::iterator DictionaryAttr::begin() const { 195 return getValue().begin(); 196 } 197 DictionaryAttr::iterator DictionaryAttr::end() const { 198 return getValue().end(); 199 } 200 size_t DictionaryAttr::size() const { return getValue().size(); } 201 202 //===----------------------------------------------------------------------===// 203 // FloatAttr 204 //===----------------------------------------------------------------------===// 205 206 FloatAttr FloatAttr::get(Type type, double value) { 207 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 208 } 209 210 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 211 return Base::getChecked(loc, StandardAttributes::Float, type, value); 212 } 213 214 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 215 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 216 } 217 218 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 219 return Base::getChecked(loc, StandardAttributes::Float, type, value); 220 } 221 222 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 223 224 double FloatAttr::getValueAsDouble() const { 225 return getValueAsDouble(getValue()); 226 } 227 double FloatAttr::getValueAsDouble(APFloat value) { 228 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 229 bool losesInfo = false; 230 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 231 &losesInfo); 232 } 233 return value.convertToDouble(); 234 } 235 236 /// Verify construction invariants. 237 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 238 if (!type.isa<FloatType>()) 239 return emitError(loc, "expected floating point type"); 240 return success(); 241 } 242 243 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 244 double value) { 245 return verifyFloatTypeInvariants(loc, type); 246 } 247 248 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 249 const APFloat &value) { 250 // Verify that the type is correct. 251 if (failed(verifyFloatTypeInvariants(loc, type))) 252 return failure(); 253 254 // Verify that the type semantics match that of the value. 255 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 256 return emitError( 257 loc, "FloatAttr type doesn't match the type implied by its value"); 258 } 259 return success(); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // SymbolRefAttr 264 //===----------------------------------------------------------------------===// 265 266 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 267 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 268 .cast<FlatSymbolRefAttr>(); 269 } 270 271 SymbolRefAttr SymbolRefAttr::get(StringRef value, 272 ArrayRef<FlatSymbolRefAttr> nestedReferences, 273 MLIRContext *ctx) { 274 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 275 } 276 277 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 278 279 StringRef SymbolRefAttr::getLeafReference() const { 280 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 281 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 282 } 283 284 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 285 return getImpl()->getNestedRefs(); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // IntegerAttr 290 //===----------------------------------------------------------------------===// 291 292 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 293 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 294 } 295 296 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 297 // This uses 64 bit APInts by default for index type. 298 if (type.isIndex()) 299 return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); 300 301 auto intType = type.cast<IntegerType>(); 302 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 303 } 304 305 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 306 307 int64_t IntegerAttr::getInt() const { 308 assert((getImpl()->getType().isIndex() || 309 getImpl()->getType().isSignlessInteger()) && 310 "must be signless integer"); 311 return getValue().getSExtValue(); 312 } 313 314 int64_t IntegerAttr::getSInt() const { 315 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 316 return getValue().getSExtValue(); 317 } 318 319 uint64_t IntegerAttr::getUInt() const { 320 assert(getImpl()->getType().isUnsignedInteger() && 321 "must be unsigned integer"); 322 return getValue().getZExtValue(); 323 } 324 325 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 326 if (type.isa<IntegerType>() || type.isa<IndexType>()) 327 return success(); 328 return emitError(loc, "expected integer or index type"); 329 } 330 331 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 332 int64_t value) { 333 return verifyIntegerTypeInvariants(loc, type); 334 } 335 336 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 337 const APInt &value) { 338 if (failed(verifyIntegerTypeInvariants(loc, type))) 339 return failure(); 340 if (auto integerType = type.dyn_cast<IntegerType>()) 341 if (integerType.getWidth() != value.getBitWidth()) 342 return emitError(loc, "integer type bit width (") 343 << integerType.getWidth() << ") doesn't match value bit width (" 344 << value.getBitWidth() << ")"; 345 return success(); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // IntegerSetAttr 350 //===----------------------------------------------------------------------===// 351 352 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 353 return Base::get(value.getConstraint(0).getContext(), 354 StandardAttributes::IntegerSet, value); 355 } 356 357 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 358 359 //===----------------------------------------------------------------------===// 360 // OpaqueAttr 361 //===----------------------------------------------------------------------===// 362 363 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 364 MLIRContext *context) { 365 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 366 type); 367 } 368 369 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 370 Type type, Location location) { 371 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 372 attrData, type); 373 } 374 375 /// Returns the dialect namespace of the opaque attribute. 376 Identifier OpaqueAttr::getDialectNamespace() const { 377 return getImpl()->dialectNamespace; 378 } 379 380 /// Returns the raw attribute data of the opaque attribute. 381 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 382 383 /// Verify the construction of an opaque attribute. 384 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 385 Identifier dialect, 386 StringRef attrData, 387 Type type) { 388 if (!Dialect::isValidNamespace(dialect.strref())) 389 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 390 return success(); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // StringAttr 395 //===----------------------------------------------------------------------===// 396 397 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 398 return get(bytes, NoneType::get(context)); 399 } 400 401 /// Get an instance of a StringAttr with the given string and Type. 402 StringAttr StringAttr::get(StringRef bytes, Type type) { 403 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 404 } 405 406 StringRef StringAttr::getValue() const { return getImpl()->value; } 407 408 //===----------------------------------------------------------------------===// 409 // TypeAttr 410 //===----------------------------------------------------------------------===// 411 412 TypeAttr TypeAttr::get(Type value) { 413 return Base::get(value.getContext(), StandardAttributes::Type, value); 414 } 415 416 Type TypeAttr::getValue() const { return getImpl()->value; } 417 418 //===----------------------------------------------------------------------===// 419 // ElementsAttr 420 //===----------------------------------------------------------------------===// 421 422 ShapedType ElementsAttr::getType() const { 423 return Attribute::getType().cast<ShapedType>(); 424 } 425 426 /// Returns the number of elements held by this attribute. 427 int64_t ElementsAttr::getNumElements() const { 428 return getType().getNumElements(); 429 } 430 431 /// Return the value at the given index. If index does not refer to a valid 432 /// element, then a null attribute is returned. 433 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 434 switch (getKind()) { 435 case StandardAttributes::DenseIntOrFPElements: 436 return cast<DenseElementsAttr>().getValue(index); 437 case StandardAttributes::OpaqueElements: 438 return cast<OpaqueElementsAttr>().getValue(index); 439 case StandardAttributes::SparseElements: 440 return cast<SparseElementsAttr>().getValue(index); 441 default: 442 llvm_unreachable("unknown ElementsAttr kind"); 443 } 444 } 445 446 /// Return if the given 'index' refers to a valid element in this attribute. 447 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 448 auto type = getType(); 449 450 // Verify that the rank of the indices matches the held type. 451 auto rank = type.getRank(); 452 if (rank != static_cast<int64_t>(index.size())) 453 return false; 454 455 // Verify that all of the indices are within the shape dimensions. 456 auto shape = type.getShape(); 457 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 458 return static_cast<int64_t>(index[i]) < shape[i]; 459 }); 460 } 461 462 ElementsAttr 463 ElementsAttr::mapValues(Type newElementType, 464 function_ref<APInt(const APInt &)> mapping) const { 465 switch (getKind()) { 466 case StandardAttributes::DenseIntOrFPElements: 467 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 468 default: 469 llvm_unreachable("unsupported ElementsAttr subtype"); 470 } 471 } 472 473 ElementsAttr 474 ElementsAttr::mapValues(Type newElementType, 475 function_ref<APInt(const APFloat &)> mapping) const { 476 switch (getKind()) { 477 case StandardAttributes::DenseIntOrFPElements: 478 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 479 default: 480 llvm_unreachable("unsupported ElementsAttr subtype"); 481 } 482 } 483 484 /// Returns the 1 dimensional flattened row-major index from the given 485 /// multi-dimensional index. 486 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 487 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 488 auto type = getType(); 489 490 // Reduce the provided multidimensional index into a flattended 1D row-major 491 // index. 492 auto rank = type.getRank(); 493 auto shape = type.getShape(); 494 uint64_t valueIndex = 0; 495 uint64_t dimMultiplier = 1; 496 for (int i = rank - 1; i >= 0; --i) { 497 valueIndex += index[i] * dimMultiplier; 498 dimMultiplier *= shape[i]; 499 } 500 return valueIndex; 501 } 502 503 //===----------------------------------------------------------------------===// 504 // DenseElementAttr Utilities 505 //===----------------------------------------------------------------------===// 506 507 /// Get the bitwidth of a dense element type within the buffer. 508 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 509 static size_t getDenseElementStorageWidth(size_t origWidth) { 510 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 511 } 512 513 /// Set a bit to a specific value. 514 static void setBit(char *rawData, size_t bitPos, bool value) { 515 if (value) 516 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 517 else 518 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 519 } 520 521 /// Return the value of the specified bit. 522 static bool getBit(const char *rawData, size_t bitPos) { 523 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 524 } 525 526 /// Writes value to the bit position `bitPos` in array `rawData`. 527 static void writeBits(char *rawData, size_t bitPos, APInt value) { 528 size_t bitWidth = value.getBitWidth(); 529 530 // If the bitwidth is 1 we just toggle the specific bit. 531 if (bitWidth == 1) 532 return setBit(rawData, bitPos, value.isOneValue()); 533 534 // Otherwise, the bit position is guaranteed to be byte aligned. 535 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 536 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 537 llvm::divideCeil(bitWidth, CHAR_BIT), 538 rawData + (bitPos / CHAR_BIT)); 539 } 540 541 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 542 /// `rawData`. 543 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 544 // Handle a boolean bit position. 545 if (bitWidth == 1) 546 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 547 548 // Otherwise, the bit position must be 8-bit aligned. 549 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 550 APInt result(bitWidth, 0); 551 std::copy_n( 552 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 553 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 554 return result; 555 } 556 557 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 558 /// same element count as 'type'. 559 template <typename Values> 560 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 561 return (values.size() == 1) || 562 (type.getNumElements() == static_cast<int64_t>(values.size())); 563 } 564 565 //===----------------------------------------------------------------------===// 566 // DenseElementAttr Iterators 567 //===----------------------------------------------------------------------===// 568 569 /// Constructs a new iterator. 570 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 571 DenseElementsAttr attr, size_t index) 572 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 573 Attribute, Attribute, Attribute>( 574 attr.getAsOpaquePointer(), index) {} 575 576 /// Accesses the Attribute value at this iterator position. 577 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 578 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 579 Type eltTy = owner.getType().getElementType(); 580 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 581 if (intEltTy.getWidth() == 1) 582 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 583 owner.getContext()); 584 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 585 } 586 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 587 IntElementIterator intIt(owner, index); 588 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 589 return FloatAttr::get(eltTy, *floatIt); 590 } 591 llvm_unreachable("unexpected element type"); 592 } 593 594 /// Constructs a new iterator. 595 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 596 DenseElementsAttr attr, size_t dataIndex) 597 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 598 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 599 600 /// Accesses the bool value at this iterator position. 601 bool DenseElementsAttr::BoolElementIterator::operator*() const { 602 return getBit(getData(), getDataIndex()); 603 } 604 605 /// Constructs a new iterator. 606 DenseElementsAttr::IntElementIterator::IntElementIterator( 607 DenseElementsAttr attr, size_t dataIndex) 608 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 609 attr.getRawData().data(), attr.isSplat(), dataIndex), 610 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} 611 612 /// Accesses the raw APInt value at this iterator position. 613 APInt DenseElementsAttr::IntElementIterator::operator*() const { 614 return readBits(getData(), 615 getDataIndex() * getDenseElementStorageWidth(bitWidth), 616 bitWidth); 617 } 618 619 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 620 const llvm::fltSemantics &smt, IntElementIterator it) 621 : llvm::mapped_iterator<IntElementIterator, 622 std::function<APFloat(const APInt &)>>( 623 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 624 625 //===----------------------------------------------------------------------===// 626 // DenseElementsAttr 627 //===----------------------------------------------------------------------===// 628 629 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 630 ArrayRef<Attribute> values) { 631 assert(type.getElementType().isIntOrIndexOrFloat() && 632 "expected int or index or float element type"); 633 assert(hasSameElementsOrSplat(type, values)); 634 635 auto eltType = type.getElementType(); 636 size_t bitWidth = getDenseElementBitWidth(eltType); 637 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 638 639 // Compress the attribute values into a character buffer. 640 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 641 values.size()); 642 APInt intVal; 643 for (unsigned i = 0, e = values.size(); i < e; ++i) { 644 assert(eltType == values[i].getType() && 645 "expected attribute value to have element type"); 646 647 switch (eltType.getKind()) { 648 case StandardTypes::BF16: 649 case StandardTypes::F16: 650 case StandardTypes::F32: 651 case StandardTypes::F64: 652 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 653 break; 654 case StandardTypes::Integer: 655 case StandardTypes::Index: 656 intVal = values[i].isa<BoolAttr>() 657 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 658 : values[i].cast<IntegerAttr>().getValue(); 659 break; 660 default: 661 llvm_unreachable("unexpected element type"); 662 } 663 assert(intVal.getBitWidth() == bitWidth && 664 "expected value to have same bitwidth as element type"); 665 writeBits(data.data(), i * storageBitWidth, intVal); 666 } 667 return DenseIntOrFPElementsAttr::getRaw(type, data, 668 /*isSplat=*/(values.size() == 1)); 669 } 670 671 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 672 ArrayRef<bool> values) { 673 assert(hasSameElementsOrSplat(type, values)); 674 assert(type.getElementType().isInteger(1)); 675 676 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 677 for (int i = 0, e = values.size(); i != e; ++i) 678 setBit(buff.data(), i, values[i]); 679 return DenseIntOrFPElementsAttr::getRaw(type, buff, 680 /*isSplat=*/(values.size() == 1)); 681 } 682 683 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 684 ArrayRef<StringRef> values) { 685 assert(!type.getElementType().isIntOrFloat()); 686 return DenseStringElementsAttr::get(type, values); 687 } 688 689 /// Constructs a dense integer elements attribute from an array of APInt 690 /// values. Each APInt value is expected to have the same bitwidth as the 691 /// element type of 'type'. 692 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 693 ArrayRef<APInt> values) { 694 assert(type.getElementType().isIntOrIndex()); 695 return DenseIntOrFPElementsAttr::getRaw(type, values); 696 } 697 698 // Constructs a dense float elements attribute from an array of APFloat 699 // values. Each APFloat value is expected to have the same bitwidth as the 700 // element type of 'type'. 701 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 702 ArrayRef<APFloat> values) { 703 assert(type.getElementType().isa<FloatType>()); 704 705 // Convert the APFloat values to APInt and create a dense elements attribute. 706 std::vector<APInt> intValues(values.size()); 707 for (unsigned i = 0, e = values.size(); i != e; ++i) 708 intValues[i] = values[i].bitcastToAPInt(); 709 return DenseIntOrFPElementsAttr::getRaw(type, intValues); 710 } 711 712 /// Construct a dense elements attribute from a raw buffer representing the 713 /// data for this attribute. Users should generally not use this methods as 714 /// the expected buffer format may not be a form the user expects. 715 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 716 ArrayRef<char> rawBuffer, 717 bool isSplatBuffer) { 718 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); 719 } 720 721 /// Check the information for a C++ data type, check if this type is valid for 722 /// the current attribute. This method is used to verify specific type 723 /// invariants that the templatized 'getValues' method cannot. 724 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt, 725 bool isSigned) { 726 // Make sure that the data element size is the same as the type element width. 727 if (getDenseElementBitWidth(type.getElementType()) != 728 static_cast<size_t>(dataEltSize * CHAR_BIT)) 729 return false; 730 731 // Check that the element type is either float or integer or index. 732 if (!isInt) 733 return type.getElementType().isa<FloatType>(); 734 735 if (type.getElementType().isIndex()) 736 return true; 737 738 auto intType = type.getElementType().dyn_cast<IntegerType>(); 739 if (!intType) 740 return false; 741 742 // Make sure signedness semantics is consistent. 743 if (intType.isSignless()) 744 return true; 745 return intType.isSigned() ? isSigned : !isSigned; 746 } 747 748 /// Defaults down the subclass implementation. 749 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 750 ArrayRef<char> data, 751 int64_t dataEltSize, 752 bool isInt, 753 bool isSigned) { 754 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 755 isInt, isSigned); 756 } 757 758 /// A method used to verify specific type invariants that the templatized 'get' 759 /// method cannot. 760 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 761 bool isSigned) const { 762 return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned); 763 } 764 765 /// Returns if this attribute corresponds to a splat, i.e. if all element 766 /// values are the same. 767 bool DenseElementsAttr::isSplat() const { 768 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 769 } 770 771 /// Return the held element values as a range of Attributes. 772 auto DenseElementsAttr::getAttributeValues() const 773 -> llvm::iterator_range<AttributeElementIterator> { 774 return {attr_value_begin(), attr_value_end()}; 775 } 776 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 777 return AttributeElementIterator(*this, 0); 778 } 779 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 780 return AttributeElementIterator(*this, getNumElements()); 781 } 782 783 /// Return the held element values as a range of bool. The element type of 784 /// this attribute must be of integer type of bitwidth 1. 785 auto DenseElementsAttr::getBoolValues() const 786 -> llvm::iterator_range<BoolElementIterator> { 787 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 788 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 789 (void)eltType; 790 return {BoolElementIterator(*this, 0), 791 BoolElementIterator(*this, getNumElements())}; 792 } 793 794 /// Return the held element values as a range of APInts. The element type of 795 /// this attribute must be of integer type. 796 auto DenseElementsAttr::getIntValues() const 797 -> llvm::iterator_range<IntElementIterator> { 798 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 799 return {raw_int_begin(), raw_int_end()}; 800 } 801 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 802 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 803 return raw_int_begin(); 804 } 805 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 806 assert(getType().getElementType().isIntOrIndex() && "expected integral type"); 807 return raw_int_end(); 808 } 809 810 /// Return the held element values as a range of APFloat. The element type of 811 /// this attribute must be of float type. 812 auto DenseElementsAttr::getFloatValues() const 813 -> llvm::iterator_range<FloatElementIterator> { 814 auto elementType = getType().getElementType().cast<FloatType>(); 815 assert(elementType.isa<FloatType>() && "expected float type"); 816 const auto &elementSemantics = elementType.getFloatSemantics(); 817 return {FloatElementIterator(elementSemantics, raw_int_begin()), 818 FloatElementIterator(elementSemantics, raw_int_end())}; 819 } 820 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 821 return getFloatValues().begin(); 822 } 823 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 824 return getFloatValues().end(); 825 } 826 827 /// Return the raw storage data held by this attribute. 828 ArrayRef<char> DenseElementsAttr::getRawData() const { 829 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; 830 } 831 832 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 833 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; 834 } 835 836 /// Return a new DenseElementsAttr that has the same data as the current 837 /// attribute, but has been reshaped to 'newType'. The new type must have the 838 /// same total number of elements as well as element type. 839 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 840 ShapedType curType = getType(); 841 if (curType == newType) 842 return *this; 843 844 (void)curType; 845 assert(newType.getElementType() == curType.getElementType() && 846 "expected the same element type"); 847 assert(newType.getNumElements() == curType.getNumElements() && 848 "expected the same number of elements"); 849 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); 850 } 851 852 DenseElementsAttr 853 DenseElementsAttr::mapValues(Type newElementType, 854 function_ref<APInt(const APInt &)> mapping) const { 855 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 856 } 857 858 DenseElementsAttr DenseElementsAttr::mapValues( 859 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 860 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 861 } 862 863 //===----------------------------------------------------------------------===// 864 // DenseStringElementsAttr 865 //===----------------------------------------------------------------------===// 866 867 DenseStringElementsAttr 868 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { 869 return Base::get(type.getContext(), StandardAttributes::DenseStringElements, 870 type, values, (values.size() == 1)); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // DenseIntOrFPElementsAttr 875 //===----------------------------------------------------------------------===// 876 877 /// Constructs a dense elements attribute from an array of raw APInt values. 878 /// Each APInt value is expected to have the same bitwidth as the element type 879 /// of 'type'. 880 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 881 ArrayRef<APInt> values) { 882 assert(hasSameElementsOrSplat(type, values)); 883 884 size_t bitWidth = getDenseElementBitWidth(type.getElementType()); 885 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 886 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 887 values.size()); 888 for (unsigned i = 0, e = values.size(); i != e; ++i) { 889 assert(values[i].getBitWidth() == bitWidth); 890 writeBits(elementData.data(), i * storageBitWidth, values[i]); 891 } 892 return DenseIntOrFPElementsAttr::getRaw(type, elementData, 893 /*isSplat=*/(values.size() == 1)); 894 } 895 896 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 897 ArrayRef<char> data, 898 bool isSplat) { 899 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 900 "type must be ranked tensor or vector"); 901 assert(type.hasStaticShape() && "type must have static shape"); 902 return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, 903 type, data, isSplat); 904 } 905 906 /// Overload of the 'getRaw' method that asserts that the given type is of 907 /// integer type. This method is used to verify type invariants that the 908 /// templatized 'get' method cannot. 909 DenseElementsAttr 910 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 911 int64_t dataEltSize, bool isInt, 912 bool isSigned) { 913 assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); 914 915 int64_t numElements = data.size() / dataEltSize; 916 assert(numElements == 1 || numElements == type.getNumElements()); 917 return getRaw(type, data, /*isSplat=*/numElements == 1); 918 } 919 920 //===----------------------------------------------------------------------===// 921 // DenseFPElementsAttr 922 //===----------------------------------------------------------------------===// 923 924 template <typename Fn, typename Attr> 925 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 926 Type newElementType, 927 llvm::SmallVectorImpl<char> &data) { 928 size_t bitWidth = getDenseElementBitWidth(newElementType); 929 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 930 931 ShapedType newArrayType; 932 if (inType.isa<RankedTensorType>()) 933 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 934 else if (inType.isa<UnrankedTensorType>()) 935 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 936 else if (inType.isa<VectorType>()) 937 newArrayType = VectorType::get(inType.getShape(), newElementType); 938 else 939 assert(newArrayType && "Unhandled tensor type"); 940 941 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 942 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 943 944 // Functor used to process a single element value of the attribute. 945 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 946 auto newInt = mapping(value); 947 assert(newInt.getBitWidth() == bitWidth); 948 writeBits(data.data(), index * storageBitWidth, newInt); 949 }; 950 951 // Check for the splat case. 952 if (attr.isSplat()) { 953 processElt(*attr.begin(), /*index=*/0); 954 return newArrayType; 955 } 956 957 // Otherwise, process all of the element values. 958 uint64_t elementIdx = 0; 959 for (auto value : attr) 960 processElt(value, elementIdx++); 961 return newArrayType; 962 } 963 964 DenseElementsAttr DenseFPElementsAttr::mapValues( 965 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 966 llvm::SmallVector<char, 8> elementData; 967 auto newArrayType = 968 mappingHelper(mapping, *this, getType(), newElementType, elementData); 969 970 return getRaw(newArrayType, elementData, isSplat()); 971 } 972 973 /// Method for supporting type inquiry through isa, cast and dyn_cast. 974 bool DenseFPElementsAttr::classof(Attribute attr) { 975 return attr.isa<DenseElementsAttr>() && 976 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 977 } 978 979 //===----------------------------------------------------------------------===// 980 // DenseIntElementsAttr 981 //===----------------------------------------------------------------------===// 982 983 DenseElementsAttr DenseIntElementsAttr::mapValues( 984 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 985 llvm::SmallVector<char, 8> elementData; 986 auto newArrayType = 987 mappingHelper(mapping, *this, getType(), newElementType, elementData); 988 989 return getRaw(newArrayType, elementData, isSplat()); 990 } 991 992 /// Method for supporting type inquiry through isa, cast and dyn_cast. 993 bool DenseIntElementsAttr::classof(Attribute attr) { 994 return attr.isa<DenseElementsAttr>() && 995 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); 996 } 997 998 //===----------------------------------------------------------------------===// 999 // OpaqueElementsAttr 1000 //===----------------------------------------------------------------------===// 1001 1002 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 1003 StringRef bytes) { 1004 assert(TensorType::isValidElementType(type.getElementType()) && 1005 "Input element type should be a valid tensor element type"); 1006 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 1007 dialect, bytes); 1008 } 1009 1010 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 1011 1012 /// Return the value at the given index. If index does not refer to a valid 1013 /// element, then a null attribute is returned. 1014 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1015 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1016 if (Dialect *dialect = getDialect()) 1017 return dialect->extractElementHook(*this, index); 1018 return Attribute(); 1019 } 1020 1021 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 1022 1023 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 1024 if (auto *d = getDialect()) 1025 return d->decodeHook(*this, result); 1026 return true; 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // SparseElementsAttr 1031 //===----------------------------------------------------------------------===// 1032 1033 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 1034 DenseElementsAttr indices, 1035 DenseElementsAttr values) { 1036 assert(indices.getType().getElementType().isInteger(64) && 1037 "expected sparse indices to be 64-bit integer values"); 1038 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 1039 "type must be ranked tensor or vector"); 1040 assert(type.hasStaticShape() && "type must have static shape"); 1041 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 1042 indices.cast<DenseIntElementsAttr>(), values); 1043 } 1044 1045 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 1046 return getImpl()->indices; 1047 } 1048 1049 DenseElementsAttr SparseElementsAttr::getValues() const { 1050 return getImpl()->values; 1051 } 1052 1053 /// Return the value of the element at the given index. 1054 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 1055 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 1056 auto type = getType(); 1057 1058 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1059 // as a 1-D index array. 1060 auto sparseIndices = getIndices(); 1061 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1062 1063 // Check to see if the indices are a splat. 1064 if (sparseIndices.isSplat()) { 1065 // If the index is also not a splat of the index value, we know that the 1066 // value is zero. 1067 auto splatIndex = *sparseIndexValues.begin(); 1068 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1069 return getZeroAttr(); 1070 1071 // If the indices are a splat, we also expect the values to be a splat. 1072 assert(getValues().isSplat() && "expected splat values"); 1073 return getValues().getSplatValue(); 1074 } 1075 1076 // Build a mapping between known indices and the offset of the stored element. 1077 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1078 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1079 size_t rank = type.getRank(); 1080 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1081 mappedIndices.try_emplace( 1082 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1083 1084 // Look for the provided index key within the mapped indices. If the provided 1085 // index is not found, then return a zero attribute. 1086 auto it = mappedIndices.find(index); 1087 if (it == mappedIndices.end()) 1088 return getZeroAttr(); 1089 1090 // Otherwise, return the held sparse value element. 1091 return getValues().getValue(it->second); 1092 } 1093 1094 /// Get a zero APFloat for the given sparse attribute. 1095 APFloat SparseElementsAttr::getZeroAPFloat() const { 1096 auto eltType = getType().getElementType().cast<FloatType>(); 1097 return APFloat(eltType.getFloatSemantics()); 1098 } 1099 1100 /// Get a zero APInt for the given sparse attribute. 1101 APInt SparseElementsAttr::getZeroAPInt() const { 1102 auto eltType = getType().getElementType().cast<IntegerType>(); 1103 return APInt::getNullValue(eltType.getWidth()); 1104 } 1105 1106 /// Get a zero attribute for the given attribute type. 1107 Attribute SparseElementsAttr::getZeroAttr() const { 1108 auto eltType = getType().getElementType(); 1109 1110 // Handle floating point elements. 1111 if (eltType.isa<FloatType>()) 1112 return FloatAttr::get(eltType, 0); 1113 1114 // Otherwise, this is an integer. 1115 auto intEltTy = eltType.cast<IntegerType>(); 1116 if (intEltTy.getWidth() == 1) 1117 return BoolAttr::get(false, eltType.getContext()); 1118 return IntegerAttr::get(eltType, 0); 1119 } 1120 1121 /// Flatten, and return, all of the sparse indices in this attribute in 1122 /// row-major order. 1123 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1124 std::vector<ptrdiff_t> flatSparseIndices; 1125 1126 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1127 // as a 1-D index array. 1128 auto sparseIndices = getIndices(); 1129 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1130 if (sparseIndices.isSplat()) { 1131 SmallVector<uint64_t, 8> indices(getType().getRank(), 1132 *sparseIndexValues.begin()); 1133 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1134 return flatSparseIndices; 1135 } 1136 1137 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1138 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1139 size_t rank = getType().getRank(); 1140 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1141 flatSparseIndices.push_back(getFlattenedIndex( 1142 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1143 return flatSparseIndices; 1144 } 1145 1146 //===----------------------------------------------------------------------===// 1147 // NamedAttributeList 1148 //===----------------------------------------------------------------------===// 1149 1150 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1151 setAttrs(attributes); 1152 } 1153 1154 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1155 return attrs ? attrs.getValue() : llvm::None; 1156 } 1157 1158 /// Replace the held attributes with ones provided in 'newAttrs'. 1159 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1160 // Don't create an attribute list if there are no attributes. 1161 if (attributes.empty()) 1162 attrs = nullptr; 1163 else 1164 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1165 } 1166 1167 /// Return the specified attribute if present, null otherwise. 1168 Attribute NamedAttributeList::get(StringRef name) const { 1169 return attrs ? attrs.get(name) : nullptr; 1170 } 1171 1172 /// Return the specified attribute if present, null otherwise. 1173 Attribute NamedAttributeList::get(Identifier name) const { 1174 return attrs ? attrs.get(name) : nullptr; 1175 } 1176 1177 /// If the an attribute exists with the specified name, change it to the new 1178 /// value. Otherwise, add a new attribute with the specified name/value. 1179 void NamedAttributeList::set(Identifier name, Attribute value) { 1180 assert(value && "attributes may never be null"); 1181 1182 // Look for an existing value for the given name, and set it in-place. 1183 ArrayRef<NamedAttribute> values = getAttrs(); 1184 auto it = llvm::find_if( 1185 values, [name](NamedAttribute attr) { return attr.first == name; }); 1186 if (it != values.end()) { 1187 // Bail out early if the value is the same as what we already have. 1188 if (it->second == value) 1189 return; 1190 1191 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); 1192 newAttrs[it - values.begin()].second = value; 1193 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1194 return; 1195 } 1196 1197 // Otherwise, insert the new attribute into its sorted position. 1198 it = llvm::lower_bound(values, name, compareNamedAttributeWithName); 1199 SmallVector<NamedAttribute, 8> newAttrs; 1200 newAttrs.reserve(values.size() + 1); 1201 newAttrs.append(values.begin(), it); 1202 newAttrs.push_back({name, value}); 1203 newAttrs.append(it, values.end()); 1204 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); 1205 } 1206 1207 /// Remove the attribute with the specified name if it exists. The return 1208 /// value indicates whether the attribute was present or not. 1209 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1210 auto origAttrs = getAttrs(); 1211 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1212 if (origAttrs[i].first == name) { 1213 // Handle the simple case of removing the only attribute in the list. 1214 if (e == 1) { 1215 attrs = nullptr; 1216 return RemoveResult::Removed; 1217 } 1218 1219 SmallVector<NamedAttribute, 8> newAttrs; 1220 newAttrs.reserve(origAttrs.size() - 1); 1221 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1222 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1223 attrs = DictionaryAttr::getWithSorted(newAttrs, 1224 newAttrs[0].second.getContext()); 1225 return RemoveResult::Removed; 1226 } 1227 } 1228 return RemoveResult::NotFound; 1229 } 1230