1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 Attribute ArrayAttr::operator[](unsigned idx) const { 72 assert(idx < size() && "index out of bounds"); 73 return getValue()[idx]; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // BoolAttr 78 //===----------------------------------------------------------------------===// 79 80 bool BoolAttr::getValue() const { return getImpl()->value; } 81 82 //===----------------------------------------------------------------------===// 83 // DictionaryAttr 84 //===----------------------------------------------------------------------===// 85 86 /// Perform a three-way comparison between the names of the specified 87 /// NamedAttributes. 88 static int compareNamedAttributes(const NamedAttribute *lhs, 89 const NamedAttribute *rhs) { 90 return lhs->first.strref().compare(rhs->first.strref()); 91 } 92 93 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 94 MLIRContext *context) { 95 assert(llvm::all_of(value, 96 [](const NamedAttribute &attr) { return attr.second; }) && 97 "value cannot have null entries"); 98 99 // We need to sort the element list to canonicalize it, but we also don't want 100 // to do a ton of work in the super common case where the element list is 101 // already sorted. 102 SmallVector<NamedAttribute, 8> storage; 103 switch (value.size()) { 104 case 0: 105 break; 106 case 1: 107 // A single element is already sorted. 108 break; 109 case 2: 110 assert(value[0].first != value[1].first && 111 "DictionaryAttr element names must be unique"); 112 113 // Don't invoke a general sort for two element case. 114 if (value[0].first.strref() > value[1].first.strref()) { 115 storage.push_back(value[1]); 116 storage.push_back(value[0]); 117 value = storage; 118 } 119 break; 120 default: 121 // Check to see they are sorted already. 122 bool isSorted = true; 123 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 124 if (value[i].first.strref() > value[i + 1].first.strref()) { 125 isSorted = false; 126 break; 127 } 128 } 129 // If not, do a general sort. 130 if (!isSorted) { 131 storage.append(value.begin(), value.end()); 132 llvm::array_pod_sort(storage.begin(), storage.end(), 133 compareNamedAttributes); 134 value = storage; 135 } 136 137 // Ensure that the attribute elements are unique. 138 assert(std::adjacent_find(value.begin(), value.end(), 139 [](NamedAttribute l, NamedAttribute r) { 140 return l.first == r.first; 141 }) == value.end() && 142 "DictionaryAttr element names must be unique"); 143 } 144 145 return Base::get(context, StandardAttributes::Dictionary, value); 146 } 147 148 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 149 return getImpl()->getElements(); 150 } 151 152 /// Return the specified attribute if present, null otherwise. 153 Attribute DictionaryAttr::get(StringRef name) const { 154 ArrayRef<NamedAttribute> values = getValue(); 155 auto compare = [](NamedAttribute attr, StringRef name) { 156 return attr.first.strref() < name; 157 }; 158 auto it = llvm::lower_bound(values, name, compare); 159 return it != values.end() && it->first.is(name) ? it->second : Attribute(); 160 } 161 Attribute DictionaryAttr::get(Identifier name) const { 162 for (auto elt : getValue()) 163 if (elt.first == name) 164 return elt.second; 165 return nullptr; 166 } 167 168 DictionaryAttr::iterator DictionaryAttr::begin() const { 169 return getValue().begin(); 170 } 171 DictionaryAttr::iterator DictionaryAttr::end() const { 172 return getValue().end(); 173 } 174 size_t DictionaryAttr::size() const { return getValue().size(); } 175 176 //===----------------------------------------------------------------------===// 177 // FloatAttr 178 //===----------------------------------------------------------------------===// 179 180 FloatAttr FloatAttr::get(Type type, double value) { 181 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 182 } 183 184 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 185 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 186 type, value); 187 } 188 189 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 190 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 191 } 192 193 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 194 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 195 type, value); 196 } 197 198 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 199 200 double FloatAttr::getValueAsDouble() const { 201 return getValueAsDouble(getValue()); 202 } 203 double FloatAttr::getValueAsDouble(APFloat value) { 204 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 205 bool losesInfo = false; 206 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 207 &losesInfo); 208 } 209 return value.convertToDouble(); 210 } 211 212 /// Verify construction invariants. 213 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, 214 Type type) { 215 if (!type.isa<FloatType>()) 216 return emitOptionalError(loc, "expected floating point type"); 217 return success(); 218 } 219 220 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 221 MLIRContext *ctx, 222 Type type, double value) { 223 return verifyFloatTypeInvariants(loc, type); 224 } 225 226 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 227 MLIRContext *ctx, 228 Type type, 229 const APFloat &value) { 230 // Verify that the type is correct. 231 if (failed(verifyFloatTypeInvariants(loc, type))) 232 return failure(); 233 234 // Verify that the type semantics match that of the value. 235 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 236 return emitOptionalError( 237 loc, "FloatAttr type doesn't match the type implied by its value"); 238 } 239 return success(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // SymbolRefAttr 244 //===----------------------------------------------------------------------===// 245 246 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 247 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 248 .cast<FlatSymbolRefAttr>(); 249 } 250 251 SymbolRefAttr SymbolRefAttr::get(StringRef value, 252 ArrayRef<FlatSymbolRefAttr> nestedReferences, 253 MLIRContext *ctx) { 254 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 255 } 256 257 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 258 259 StringRef SymbolRefAttr::getLeafReference() const { 260 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 261 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 262 } 263 264 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 265 return getImpl()->getNestedRefs(); 266 } 267 268 //===----------------------------------------------------------------------===// 269 // IntegerAttr 270 //===----------------------------------------------------------------------===// 271 272 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 273 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 274 } 275 276 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 277 // This uses 64 bit APInts by default for index type. 278 if (type.isIndex()) 279 return get(type, APInt(64, value)); 280 281 auto intType = type.cast<IntegerType>(); 282 return get(type, APInt(intType.getWidth(), value)); 283 } 284 285 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 286 287 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } 288 289 static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc, 290 Type type) { 291 if (type.isa<IntegerType>() || type.isa<IndexType>()) 292 return success(); 293 return emitOptionalError(loc, "expected integer or index type"); 294 } 295 296 LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc, 297 MLIRContext *ctx, 298 Type type, 299 int64_t value) { 300 return verifyIntegerTypeInvariants(loc, type); 301 } 302 303 LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc, 304 MLIRContext *ctx, 305 Type type, 306 const APInt &value) { 307 if (failed(verifyIntegerTypeInvariants(loc, type))) 308 return failure(); 309 if (auto integerType = type.dyn_cast<IntegerType>()) 310 if (integerType.getWidth() != value.getBitWidth()) 311 return emitOptionalError( 312 loc, "integer type bit width (", integerType.getWidth(), 313 ") doesn't match value bit width (", value.getBitWidth(), ")"); 314 return success(); 315 } 316 317 //===----------------------------------------------------------------------===// 318 // IntegerSetAttr 319 //===----------------------------------------------------------------------===// 320 321 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 322 return Base::get(value.getConstraint(0).getContext(), 323 StandardAttributes::IntegerSet, value); 324 } 325 326 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 327 328 //===----------------------------------------------------------------------===// 329 // OpaqueAttr 330 //===----------------------------------------------------------------------===// 331 332 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 333 MLIRContext *context) { 334 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 335 type); 336 } 337 338 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 339 Type type, Location location) { 340 return Base::getChecked(location, type.getContext(), 341 StandardAttributes::Opaque, dialect, attrData, type); 342 } 343 344 /// Returns the dialect namespace of the opaque attribute. 345 Identifier OpaqueAttr::getDialectNamespace() const { 346 return getImpl()->dialectNamespace; 347 } 348 349 /// Returns the raw attribute data of the opaque attribute. 350 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 351 352 /// Verify the construction of an opaque attribute. 353 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, 354 MLIRContext *context, 355 Identifier dialect, 356 StringRef attrData, 357 Type type) { 358 if (!Dialect::isValidNamespace(dialect.strref())) 359 return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); 360 return success(); 361 } 362 363 //===----------------------------------------------------------------------===// 364 // StringAttr 365 //===----------------------------------------------------------------------===// 366 367 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 368 return get(bytes, NoneType::get(context)); 369 } 370 371 /// Get an instance of a StringAttr with the given string and Type. 372 StringAttr StringAttr::get(StringRef bytes, Type type) { 373 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 374 } 375 376 StringRef StringAttr::getValue() const { return getImpl()->value; } 377 378 //===----------------------------------------------------------------------===// 379 // TypeAttr 380 //===----------------------------------------------------------------------===// 381 382 TypeAttr TypeAttr::get(Type value) { 383 return Base::get(value.getContext(), StandardAttributes::Type, value); 384 } 385 386 Type TypeAttr::getValue() const { return getImpl()->value; } 387 388 //===----------------------------------------------------------------------===// 389 // ElementsAttr 390 //===----------------------------------------------------------------------===// 391 392 ShapedType ElementsAttr::getType() const { 393 return Attribute::getType().cast<ShapedType>(); 394 } 395 396 /// Returns the number of elements held by this attribute. 397 int64_t ElementsAttr::getNumElements() const { 398 return getType().getNumElements(); 399 } 400 401 /// Return the value at the given index. If index does not refer to a valid 402 /// element, then a null attribute is returned. 403 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 404 switch (getKind()) { 405 case StandardAttributes::DenseElements: 406 return cast<DenseElementsAttr>().getValue(index); 407 case StandardAttributes::OpaqueElements: 408 return cast<OpaqueElementsAttr>().getValue(index); 409 case StandardAttributes::SparseElements: 410 return cast<SparseElementsAttr>().getValue(index); 411 default: 412 llvm_unreachable("unknown ElementsAttr kind"); 413 } 414 } 415 416 /// Return if the given 'index' refers to a valid element in this attribute. 417 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 418 auto type = getType(); 419 420 // Verify that the rank of the indices matches the held type. 421 auto rank = type.getRank(); 422 if (rank != static_cast<int64_t>(index.size())) 423 return false; 424 425 // Verify that all of the indices are within the shape dimensions. 426 auto shape = type.getShape(); 427 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 428 return static_cast<int64_t>(index[i]) < shape[i]; 429 }); 430 } 431 432 ElementsAttr 433 ElementsAttr::mapValues(Type newElementType, 434 function_ref<APInt(const APInt &)> mapping) const { 435 switch (getKind()) { 436 case StandardAttributes::DenseElements: 437 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 438 default: 439 llvm_unreachable("unsupported ElementsAttr subtype"); 440 } 441 } 442 443 ElementsAttr 444 ElementsAttr::mapValues(Type newElementType, 445 function_ref<APInt(const APFloat &)> mapping) const { 446 switch (getKind()) { 447 case StandardAttributes::DenseElements: 448 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 449 default: 450 llvm_unreachable("unsupported ElementsAttr subtype"); 451 } 452 } 453 454 /// Returns the 1 dimensional flattened row-major index from the given 455 /// multi-dimensional index. 456 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 457 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 458 auto type = getType(); 459 460 // Reduce the provided multidimensional index into a flattended 1D row-major 461 // index. 462 auto rank = type.getRank(); 463 auto shape = type.getShape(); 464 uint64_t valueIndex = 0; 465 uint64_t dimMultiplier = 1; 466 for (int i = rank - 1; i >= 0; --i) { 467 valueIndex += index[i] * dimMultiplier; 468 dimMultiplier *= shape[i]; 469 } 470 return valueIndex; 471 } 472 473 //===----------------------------------------------------------------------===// 474 // DenseElementAttr Utilities 475 //===----------------------------------------------------------------------===// 476 477 static size_t getDenseElementBitwidth(Type eltType) { 478 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 479 // with double semantics. 480 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 481 } 482 483 /// Get the bitwidth of a dense element type within the buffer. 484 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 485 static size_t getDenseElementStorageWidth(size_t origWidth) { 486 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 487 } 488 489 /// Set a bit to a specific value. 490 static void setBit(char *rawData, size_t bitPos, bool value) { 491 if (value) 492 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 493 else 494 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 495 } 496 497 /// Return the value of the specified bit. 498 static bool getBit(const char *rawData, size_t bitPos) { 499 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 500 } 501 502 /// Writes value to the bit position `bitPos` in array `rawData`. 503 static void writeBits(char *rawData, size_t bitPos, APInt value) { 504 size_t bitWidth = value.getBitWidth(); 505 506 // If the bitwidth is 1 we just toggle the specific bit. 507 if (bitWidth == 1) 508 return setBit(rawData, bitPos, value.isOneValue()); 509 510 // Otherwise, the bit position is guaranteed to be byte aligned. 511 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 512 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 513 llvm::divideCeil(bitWidth, CHAR_BIT), 514 rawData + (bitPos / CHAR_BIT)); 515 } 516 517 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 518 /// `rawData`. 519 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 520 // Handle a boolean bit position. 521 if (bitWidth == 1) 522 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 523 524 // Otherwise, the bit position must be 8-bit aligned. 525 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 526 APInt result(bitWidth, 0); 527 std::copy_n( 528 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 529 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 530 return result; 531 } 532 533 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 534 /// same element count as 'type'. 535 template <typename Values> 536 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 537 return (values.size() == 1) || 538 (type.getNumElements() == static_cast<int64_t>(values.size())); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // DenseElementAttr Iterators 543 //===----------------------------------------------------------------------===// 544 545 /// Constructs a new iterator. 546 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 547 DenseElementsAttr attr, size_t index) 548 : indexed_accessor_iterator<AttributeElementIterator, const void *, 549 Attribute, Attribute, Attribute>( 550 attr.getAsOpaquePointer(), index) {} 551 552 /// Accesses the Attribute value at this iterator position. 553 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 554 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 555 Type eltTy = owner.getType().getElementType(); 556 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 557 if (intEltTy.getWidth() == 1) 558 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 559 owner.getContext()); 560 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 561 } 562 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 563 IntElementIterator intIt(owner, index); 564 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 565 return FloatAttr::get(eltTy, *floatIt); 566 } 567 llvm_unreachable("unexpected element type"); 568 } 569 570 /// Constructs a new iterator. 571 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 572 DenseElementsAttr attr, size_t dataIndex) 573 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 574 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 575 576 /// Accesses the bool value at this iterator position. 577 bool DenseElementsAttr::BoolElementIterator::operator*() const { 578 return getBit(getData(), getDataIndex()); 579 } 580 581 /// Constructs a new iterator. 582 DenseElementsAttr::IntElementIterator::IntElementIterator( 583 DenseElementsAttr attr, size_t dataIndex) 584 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 585 attr.getRawData().data(), attr.isSplat(), dataIndex), 586 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 587 588 /// Accesses the raw APInt value at this iterator position. 589 APInt DenseElementsAttr::IntElementIterator::operator*() const { 590 return readBits(getData(), 591 getDataIndex() * getDenseElementStorageWidth(bitWidth), 592 bitWidth); 593 } 594 595 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 596 const llvm::fltSemantics &smt, IntElementIterator it) 597 : llvm::mapped_iterator<IntElementIterator, 598 std::function<APFloat(const APInt &)>>( 599 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 600 601 //===----------------------------------------------------------------------===// 602 // DenseElementsAttr 603 //===----------------------------------------------------------------------===// 604 605 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 606 ArrayRef<Attribute> values) { 607 assert(type.getElementType().isIntOrFloat() && 608 "expected int or float element type"); 609 assert(hasSameElementsOrSplat(type, values)); 610 611 auto eltType = type.getElementType(); 612 size_t bitWidth = getDenseElementBitwidth(eltType); 613 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 614 615 // Compress the attribute values into a character buffer. 616 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 617 values.size()); 618 APInt intVal; 619 for (unsigned i = 0, e = values.size(); i < e; ++i) { 620 assert(eltType == values[i].getType() && 621 "expected attribute value to have element type"); 622 623 switch (eltType.getKind()) { 624 case StandardTypes::BF16: 625 case StandardTypes::F16: 626 case StandardTypes::F32: 627 case StandardTypes::F64: 628 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 629 break; 630 case StandardTypes::Integer: 631 intVal = values[i].isa<BoolAttr>() 632 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 633 : values[i].cast<IntegerAttr>().getValue(); 634 break; 635 default: 636 llvm_unreachable("unexpected element type"); 637 } 638 assert(intVal.getBitWidth() == bitWidth && 639 "expected value to have same bitwidth as element type"); 640 writeBits(data.data(), i * storageBitWidth, intVal); 641 } 642 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 643 } 644 645 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 646 ArrayRef<bool> values) { 647 assert(hasSameElementsOrSplat(type, values)); 648 assert(type.getElementType().isInteger(1)); 649 650 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 651 for (int i = 0, e = values.size(); i != e; ++i) 652 setBit(buff.data(), i, values[i]); 653 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 654 } 655 656 /// Constructs a dense integer elements attribute from an array of APInt 657 /// values. Each APInt value is expected to have the same bitwidth as the 658 /// element type of 'type'. 659 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 660 ArrayRef<APInt> values) { 661 assert(type.getElementType().isa<IntegerType>()); 662 return getRaw(type, values); 663 } 664 665 // Constructs a dense float elements attribute from an array of APFloat 666 // values. Each APFloat value is expected to have the same bitwidth as the 667 // element type of 'type'. 668 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 669 ArrayRef<APFloat> values) { 670 assert(type.getElementType().isa<FloatType>()); 671 672 // Convert the APFloat values to APInt and create a dense elements attribute. 673 std::vector<APInt> intValues(values.size()); 674 for (unsigned i = 0, e = values.size(); i != e; ++i) 675 intValues[i] = values[i].bitcastToAPInt(); 676 return getRaw(type, intValues); 677 } 678 679 // Constructs a dense elements attribute from an array of raw APInt values. 680 // Each APInt value is expected to have the same bitwidth as the element type 681 // of 'type'. 682 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 683 ArrayRef<APInt> values) { 684 assert(hasSameElementsOrSplat(type, values)); 685 686 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 687 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 688 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 689 values.size()); 690 for (unsigned i = 0, e = values.size(); i != e; ++i) { 691 assert(values[i].getBitWidth() == bitWidth); 692 writeBits(elementData.data(), i * storageBitWidth, values[i]); 693 } 694 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 695 } 696 697 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 698 ArrayRef<char> data, bool isSplat) { 699 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 700 "type must be ranked tensor or vector"); 701 assert(type.hasStaticShape() && "type must have static shape"); 702 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 703 data, isSplat); 704 } 705 706 /// Check the information for a c++ data type, check if this type is valid for 707 /// the current attribute. This method is used to verify specific type 708 /// invariants that the templatized 'getValues' method cannot. 709 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, 710 bool isInt) { 711 // Make sure that the data element size is the same as the type element width. 712 if (getDenseElementBitwidth(type.getElementType()) != 713 static_cast<size_t>(dataEltSize * CHAR_BIT)) 714 return false; 715 716 // Check that the element type is valid. 717 return isInt ? type.getElementType().isa<IntegerType>() 718 : type.getElementType().isa<FloatType>(); 719 } 720 721 /// Overload of the 'getRaw' method that asserts that the given type is of 722 /// integer type. This method is used to verify type invariants that the 723 /// templatized 'get' method cannot. 724 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 725 ArrayRef<char> data, 726 int64_t dataEltSize, 727 bool isInt) { 728 assert(::isValidIntOrFloat(type, dataEltSize, isInt)); 729 730 int64_t numElements = data.size() / dataEltSize; 731 assert(numElements == 1 || numElements == type.getNumElements()); 732 return getRaw(type, data, /*isSplat=*/numElements == 1); 733 } 734 735 /// A method used to verify specific type invariants that the templatized 'get' 736 /// method cannot. 737 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, 738 bool isInt) const { 739 return ::isValidIntOrFloat(getType(), dataEltSize, isInt); 740 } 741 742 /// Return the raw storage data held by this attribute. 743 ArrayRef<char> DenseElementsAttr::getRawData() const { 744 return static_cast<ImplType *>(impl)->data; 745 } 746 747 /// Returns if this attribute corresponds to a splat, i.e. if all element 748 /// values are the same. 749 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 750 751 /// Return the held element values as a range of Attributes. 752 auto DenseElementsAttr::getAttributeValues() const 753 -> llvm::iterator_range<AttributeElementIterator> { 754 return {attr_value_begin(), attr_value_end()}; 755 } 756 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 757 return AttributeElementIterator(*this, 0); 758 } 759 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 760 return AttributeElementIterator(*this, getNumElements()); 761 } 762 763 /// Return the held element values as a range of bool. The element type of 764 /// this attribute must be of integer type of bitwidth 1. 765 auto DenseElementsAttr::getBoolValues() const 766 -> llvm::iterator_range<BoolElementIterator> { 767 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 768 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 769 (void)eltType; 770 return {BoolElementIterator(*this, 0), 771 BoolElementIterator(*this, getNumElements())}; 772 } 773 774 /// Return the held element values as a range of APInts. The element type of 775 /// this attribute must be of integer type. 776 auto DenseElementsAttr::getIntValues() const 777 -> llvm::iterator_range<IntElementIterator> { 778 assert(getType().getElementType().isa<IntegerType>() && 779 "expected integer type"); 780 return {raw_int_begin(), raw_int_end()}; 781 } 782 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 783 assert(getType().getElementType().isa<IntegerType>() && 784 "expected integer type"); 785 return raw_int_begin(); 786 } 787 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 788 assert(getType().getElementType().isa<IntegerType>() && 789 "expected integer type"); 790 return raw_int_end(); 791 } 792 793 /// Return the held element values as a range of APFloat. The element type of 794 /// this attribute must be of float type. 795 auto DenseElementsAttr::getFloatValues() const 796 -> llvm::iterator_range<FloatElementIterator> { 797 auto elementType = getType().getElementType().cast<FloatType>(); 798 assert(elementType.isa<FloatType>() && "expected float type"); 799 const auto &elementSemantics = elementType.getFloatSemantics(); 800 return {FloatElementIterator(elementSemantics, raw_int_begin()), 801 FloatElementIterator(elementSemantics, raw_int_end())}; 802 } 803 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 804 return getFloatValues().begin(); 805 } 806 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 807 return getFloatValues().end(); 808 } 809 810 /// Return a new DenseElementsAttr that has the same data as the current 811 /// attribute, but has been reshaped to 'newType'. The new type must have the 812 /// same total number of elements as well as element type. 813 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 814 ShapedType curType = getType(); 815 if (curType == newType) 816 return *this; 817 818 (void)curType; 819 assert(newType.getElementType() == curType.getElementType() && 820 "expected the same element type"); 821 assert(newType.getNumElements() == curType.getNumElements() && 822 "expected the same number of elements"); 823 return getRaw(newType, getRawData(), isSplat()); 824 } 825 826 DenseElementsAttr 827 DenseElementsAttr::mapValues(Type newElementType, 828 function_ref<APInt(const APInt &)> mapping) const { 829 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 830 } 831 832 DenseElementsAttr DenseElementsAttr::mapValues( 833 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 834 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 835 } 836 837 //===----------------------------------------------------------------------===// 838 // DenseFPElementsAttr 839 //===----------------------------------------------------------------------===// 840 841 template <typename Fn, typename Attr> 842 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 843 Type newElementType, 844 llvm::SmallVectorImpl<char> &data) { 845 size_t bitWidth = getDenseElementBitwidth(newElementType); 846 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 847 848 ShapedType newArrayType; 849 if (inType.isa<RankedTensorType>()) 850 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 851 else if (inType.isa<UnrankedTensorType>()) 852 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 853 else if (inType.isa<VectorType>()) 854 newArrayType = VectorType::get(inType.getShape(), newElementType); 855 else 856 assert(newArrayType && "Unhandled tensor type"); 857 858 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 859 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 860 861 // Functor used to process a single element value of the attribute. 862 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 863 auto newInt = mapping(value); 864 assert(newInt.getBitWidth() == bitWidth); 865 writeBits(data.data(), index * storageBitWidth, newInt); 866 }; 867 868 // Check for the splat case. 869 if (attr.isSplat()) { 870 processElt(*attr.begin(), /*index=*/0); 871 return newArrayType; 872 } 873 874 // Otherwise, process all of the element values. 875 uint64_t elementIdx = 0; 876 for (auto value : attr) 877 processElt(value, elementIdx++); 878 return newArrayType; 879 } 880 881 DenseElementsAttr DenseFPElementsAttr::mapValues( 882 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 883 llvm::SmallVector<char, 8> elementData; 884 auto newArrayType = 885 mappingHelper(mapping, *this, getType(), newElementType, elementData); 886 887 return getRaw(newArrayType, elementData, isSplat()); 888 } 889 890 /// Method for supporting type inquiry through isa, cast and dyn_cast. 891 bool DenseFPElementsAttr::classof(Attribute attr) { 892 return attr.isa<DenseElementsAttr>() && 893 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 894 } 895 896 //===----------------------------------------------------------------------===// 897 // DenseIntElementsAttr 898 //===----------------------------------------------------------------------===// 899 900 DenseElementsAttr DenseIntElementsAttr::mapValues( 901 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 902 llvm::SmallVector<char, 8> elementData; 903 auto newArrayType = 904 mappingHelper(mapping, *this, getType(), newElementType, elementData); 905 906 return getRaw(newArrayType, elementData, isSplat()); 907 } 908 909 /// Method for supporting type inquiry through isa, cast and dyn_cast. 910 bool DenseIntElementsAttr::classof(Attribute attr) { 911 return attr.isa<DenseElementsAttr>() && 912 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 913 } 914 915 //===----------------------------------------------------------------------===// 916 // OpaqueElementsAttr 917 //===----------------------------------------------------------------------===// 918 919 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 920 StringRef bytes) { 921 assert(TensorType::isValidElementType(type.getElementType()) && 922 "Input element type should be a valid tensor element type"); 923 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 924 dialect, bytes); 925 } 926 927 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 928 929 /// Return the value at the given index. If index does not refer to a valid 930 /// element, then a null attribute is returned. 931 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 932 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 933 if (Dialect *dialect = getDialect()) 934 return dialect->extractElementHook(*this, index); 935 return Attribute(); 936 } 937 938 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 939 940 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 941 if (auto *d = getDialect()) 942 return d->decodeHook(*this, result); 943 return true; 944 } 945 946 //===----------------------------------------------------------------------===// 947 // SparseElementsAttr 948 //===----------------------------------------------------------------------===// 949 950 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 951 DenseElementsAttr indices, 952 DenseElementsAttr values) { 953 assert(indices.getType().getElementType().isInteger(64) && 954 "expected sparse indices to be 64-bit integer values"); 955 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 956 "type must be ranked tensor or vector"); 957 assert(type.hasStaticShape() && "type must have static shape"); 958 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 959 indices.cast<DenseIntElementsAttr>(), values); 960 } 961 962 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 963 return getImpl()->indices; 964 } 965 966 DenseElementsAttr SparseElementsAttr::getValues() const { 967 return getImpl()->values; 968 } 969 970 /// Return the value of the element at the given index. 971 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 972 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 973 auto type = getType(); 974 975 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 976 // as a 1-D index array. 977 auto sparseIndices = getIndices(); 978 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 979 980 // Check to see if the indices are a splat. 981 if (sparseIndices.isSplat()) { 982 // If the index is also not a splat of the index value, we know that the 983 // value is zero. 984 auto splatIndex = *sparseIndexValues.begin(); 985 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 986 return getZeroAttr(); 987 988 // If the indices are a splat, we also expect the values to be a splat. 989 assert(getValues().isSplat() && "expected splat values"); 990 return getValues().getSplatValue(); 991 } 992 993 // Build a mapping between known indices and the offset of the stored element. 994 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 995 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 996 size_t rank = type.getRank(); 997 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 998 mappedIndices.try_emplace( 999 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1000 1001 // Look for the provided index key within the mapped indices. If the provided 1002 // index is not found, then return a zero attribute. 1003 auto it = mappedIndices.find(index); 1004 if (it == mappedIndices.end()) 1005 return getZeroAttr(); 1006 1007 // Otherwise, return the held sparse value element. 1008 return getValues().getValue(it->second); 1009 } 1010 1011 /// Get a zero APFloat for the given sparse attribute. 1012 APFloat SparseElementsAttr::getZeroAPFloat() const { 1013 auto eltType = getType().getElementType().cast<FloatType>(); 1014 return APFloat(eltType.getFloatSemantics()); 1015 } 1016 1017 /// Get a zero APInt for the given sparse attribute. 1018 APInt SparseElementsAttr::getZeroAPInt() const { 1019 auto eltType = getType().getElementType().cast<IntegerType>(); 1020 return APInt::getNullValue(eltType.getWidth()); 1021 } 1022 1023 /// Get a zero attribute for the given attribute type. 1024 Attribute SparseElementsAttr::getZeroAttr() const { 1025 auto eltType = getType().getElementType(); 1026 1027 // Handle floating point elements. 1028 if (eltType.isa<FloatType>()) 1029 return FloatAttr::get(eltType, 0); 1030 1031 // Otherwise, this is an integer. 1032 auto intEltTy = eltType.cast<IntegerType>(); 1033 if (intEltTy.getWidth() == 1) 1034 return BoolAttr::get(false, eltType.getContext()); 1035 return IntegerAttr::get(eltType, 0); 1036 } 1037 1038 /// Flatten, and return, all of the sparse indices in this attribute in 1039 /// row-major order. 1040 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1041 std::vector<ptrdiff_t> flatSparseIndices; 1042 1043 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1044 // as a 1-D index array. 1045 auto sparseIndices = getIndices(); 1046 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1047 if (sparseIndices.isSplat()) { 1048 SmallVector<uint64_t, 8> indices(getType().getRank(), 1049 *sparseIndexValues.begin()); 1050 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1051 return flatSparseIndices; 1052 } 1053 1054 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1055 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1056 size_t rank = getType().getRank(); 1057 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1058 flatSparseIndices.push_back(getFlattenedIndex( 1059 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1060 return flatSparseIndices; 1061 } 1062 1063 //===----------------------------------------------------------------------===// 1064 // NamedAttributeList 1065 //===----------------------------------------------------------------------===// 1066 1067 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1068 setAttrs(attributes); 1069 } 1070 1071 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1072 return attrs ? attrs.getValue() : llvm::None; 1073 } 1074 1075 /// Replace the held attributes with ones provided in 'newAttrs'. 1076 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1077 // Don't create an attribute list if there are no attributes. 1078 if (attributes.empty()) 1079 attrs = nullptr; 1080 else 1081 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1082 } 1083 1084 /// Return the specified attribute if present, null otherwise. 1085 Attribute NamedAttributeList::get(StringRef name) const { 1086 return attrs ? attrs.get(name) : nullptr; 1087 } 1088 1089 /// Return the specified attribute if present, null otherwise. 1090 Attribute NamedAttributeList::get(Identifier name) const { 1091 return attrs ? attrs.get(name) : nullptr; 1092 } 1093 1094 /// If the an attribute exists with the specified name, change it to the new 1095 /// value. Otherwise, add a new attribute with the specified name/value. 1096 void NamedAttributeList::set(Identifier name, Attribute value) { 1097 assert(value && "attributes may never be null"); 1098 1099 // If we already have this attribute, replace it. 1100 auto origAttrs = getAttrs(); 1101 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1102 for (auto &elt : newAttrs) 1103 if (elt.first == name) { 1104 elt.second = value; 1105 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1106 return; 1107 } 1108 1109 // Otherwise, add it. 1110 newAttrs.push_back({name, value}); 1111 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1112 } 1113 1114 /// Remove the attribute with the specified name if it exists. The return 1115 /// value indicates whether the attribute was present or not. 1116 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1117 auto origAttrs = getAttrs(); 1118 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1119 if (origAttrs[i].first == name) { 1120 // Handle the simple case of removing the only attribute in the list. 1121 if (e == 1) { 1122 attrs = nullptr; 1123 return RemoveResult::Removed; 1124 } 1125 1126 SmallVector<NamedAttribute, 8> newAttrs; 1127 newAttrs.reserve(origAttrs.size() - 1); 1128 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1129 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1130 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1131 return RemoveResult::Removed; 1132 } 1133 } 1134 return RemoveResult::NotFound; 1135 } 1136