1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 Attribute ArrayAttr::operator[](unsigned idx) const { 72 assert(idx < size() && "index out of bounds"); 73 return getValue()[idx]; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // BoolAttr 78 //===----------------------------------------------------------------------===// 79 80 bool BoolAttr::getValue() const { return getImpl()->value; } 81 82 //===----------------------------------------------------------------------===// 83 // DictionaryAttr 84 //===----------------------------------------------------------------------===// 85 86 /// Perform a three-way comparison between the names of the specified 87 /// NamedAttributes. 88 static int compareNamedAttributes(const NamedAttribute *lhs, 89 const NamedAttribute *rhs) { 90 return lhs->first.strref().compare(rhs->first.strref()); 91 } 92 93 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 94 MLIRContext *context) { 95 assert(llvm::all_of(value, 96 [](const NamedAttribute &attr) { return attr.second; }) && 97 "value cannot have null entries"); 98 99 // We need to sort the element list to canonicalize it, but we also don't want 100 // to do a ton of work in the super common case where the element list is 101 // already sorted. 102 SmallVector<NamedAttribute, 8> storage; 103 switch (value.size()) { 104 case 0: 105 break; 106 case 1: 107 // A single element is already sorted. 108 break; 109 case 2: 110 assert(value[0].first != value[1].first && 111 "DictionaryAttr element names must be unique"); 112 113 // Don't invoke a general sort for two element case. 114 if (value[0].first.strref() > value[1].first.strref()) { 115 storage.push_back(value[1]); 116 storage.push_back(value[0]); 117 value = storage; 118 } 119 break; 120 default: 121 // Check to see they are sorted already. 122 bool isSorted = true; 123 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 124 if (value[i].first.strref() > value[i + 1].first.strref()) { 125 isSorted = false; 126 break; 127 } 128 } 129 // If not, do a general sort. 130 if (!isSorted) { 131 storage.append(value.begin(), value.end()); 132 llvm::array_pod_sort(storage.begin(), storage.end(), 133 compareNamedAttributes); 134 value = storage; 135 } 136 137 // Ensure that the attribute elements are unique. 138 assert(std::adjacent_find(value.begin(), value.end(), 139 [](NamedAttribute l, NamedAttribute r) { 140 return l.first == r.first; 141 }) == value.end() && 142 "DictionaryAttr element names must be unique"); 143 } 144 145 return Base::get(context, StandardAttributes::Dictionary, value); 146 } 147 148 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 149 return getImpl()->getElements(); 150 } 151 152 /// Return the specified attribute if present, null otherwise. 153 Attribute DictionaryAttr::get(StringRef name) const { 154 ArrayRef<NamedAttribute> values = getValue(); 155 auto compare = [](NamedAttribute attr, StringRef name) { 156 return attr.first.strref() < name; 157 }; 158 auto it = llvm::lower_bound(values, name, compare); 159 return it != values.end() && it->first.is(name) ? it->second : Attribute(); 160 } 161 Attribute DictionaryAttr::get(Identifier name) const { 162 for (auto elt : getValue()) 163 if (elt.first == name) 164 return elt.second; 165 return nullptr; 166 } 167 168 DictionaryAttr::iterator DictionaryAttr::begin() const { 169 return getValue().begin(); 170 } 171 DictionaryAttr::iterator DictionaryAttr::end() const { 172 return getValue().end(); 173 } 174 size_t DictionaryAttr::size() const { return getValue().size(); } 175 176 //===----------------------------------------------------------------------===// 177 // FloatAttr 178 //===----------------------------------------------------------------------===// 179 180 FloatAttr FloatAttr::get(Type type, double value) { 181 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 182 } 183 184 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 185 return Base::getChecked(loc, StandardAttributes::Float, type, value); 186 } 187 188 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 189 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 190 } 191 192 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 193 return Base::getChecked(loc, StandardAttributes::Float, type, value); 194 } 195 196 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 197 198 double FloatAttr::getValueAsDouble() const { 199 return getValueAsDouble(getValue()); 200 } 201 double FloatAttr::getValueAsDouble(APFloat value) { 202 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 203 bool losesInfo = false; 204 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 205 &losesInfo); 206 } 207 return value.convertToDouble(); 208 } 209 210 /// Verify construction invariants. 211 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { 212 if (!type.isa<FloatType>()) 213 return emitError(loc, "expected floating point type"); 214 return success(); 215 } 216 217 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 218 double value) { 219 return verifyFloatTypeInvariants(loc, type); 220 } 221 222 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, 223 const APFloat &value) { 224 // Verify that the type is correct. 225 if (failed(verifyFloatTypeInvariants(loc, type))) 226 return failure(); 227 228 // Verify that the type semantics match that of the value. 229 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 230 return emitError( 231 loc, "FloatAttr type doesn't match the type implied by its value"); 232 } 233 return success(); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // SymbolRefAttr 238 //===----------------------------------------------------------------------===// 239 240 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 241 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 242 .cast<FlatSymbolRefAttr>(); 243 } 244 245 SymbolRefAttr SymbolRefAttr::get(StringRef value, 246 ArrayRef<FlatSymbolRefAttr> nestedReferences, 247 MLIRContext *ctx) { 248 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 249 } 250 251 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 252 253 StringRef SymbolRefAttr::getLeafReference() const { 254 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 255 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 256 } 257 258 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 259 return getImpl()->getNestedRefs(); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // IntegerAttr 264 //===----------------------------------------------------------------------===// 265 266 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 267 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 268 } 269 270 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 271 // This uses 64 bit APInts by default for index type. 272 if (type.isIndex()) 273 return get(type, APInt(64, value)); 274 275 auto intType = type.cast<IntegerType>(); 276 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); 277 } 278 279 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 280 281 int64_t IntegerAttr::getInt() const { 282 assert((getImpl()->getType().isIndex() || 283 getImpl()->getType().isSignlessInteger()) && 284 "must be signless integer"); 285 return getValue().getSExtValue(); 286 } 287 288 int64_t IntegerAttr::getSInt() const { 289 assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); 290 return getValue().getSExtValue(); 291 } 292 293 uint64_t IntegerAttr::getUInt() const { 294 assert(getImpl()->getType().isUnsignedInteger() && 295 "must be unsigned integer"); 296 return getValue().getZExtValue(); 297 } 298 299 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { 300 if (type.isa<IntegerType>() || type.isa<IndexType>()) 301 return success(); 302 return emitError(loc, "expected integer or index type"); 303 } 304 305 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 306 int64_t value) { 307 return verifyIntegerTypeInvariants(loc, type); 308 } 309 310 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, 311 const APInt &value) { 312 if (failed(verifyIntegerTypeInvariants(loc, type))) 313 return failure(); 314 if (auto integerType = type.dyn_cast<IntegerType>()) 315 if (integerType.getWidth() != value.getBitWidth()) 316 return emitError(loc, "integer type bit width (") 317 << integerType.getWidth() << ") doesn't match value bit width (" 318 << value.getBitWidth() << ")"; 319 return success(); 320 } 321 322 //===----------------------------------------------------------------------===// 323 // IntegerSetAttr 324 //===----------------------------------------------------------------------===// 325 326 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 327 return Base::get(value.getConstraint(0).getContext(), 328 StandardAttributes::IntegerSet, value); 329 } 330 331 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 332 333 //===----------------------------------------------------------------------===// 334 // OpaqueAttr 335 //===----------------------------------------------------------------------===// 336 337 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 338 MLIRContext *context) { 339 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 340 type); 341 } 342 343 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 344 Type type, Location location) { 345 return Base::getChecked(location, StandardAttributes::Opaque, dialect, 346 attrData, type); 347 } 348 349 /// Returns the dialect namespace of the opaque attribute. 350 Identifier OpaqueAttr::getDialectNamespace() const { 351 return getImpl()->dialectNamespace; 352 } 353 354 /// Returns the raw attribute data of the opaque attribute. 355 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 356 357 /// Verify the construction of an opaque attribute. 358 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, 359 Identifier dialect, 360 StringRef attrData, 361 Type type) { 362 if (!Dialect::isValidNamespace(dialect.strref())) 363 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 364 return success(); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // StringAttr 369 //===----------------------------------------------------------------------===// 370 371 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 372 return get(bytes, NoneType::get(context)); 373 } 374 375 /// Get an instance of a StringAttr with the given string and Type. 376 StringAttr StringAttr::get(StringRef bytes, Type type) { 377 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 378 } 379 380 StringRef StringAttr::getValue() const { return getImpl()->value; } 381 382 //===----------------------------------------------------------------------===// 383 // TypeAttr 384 //===----------------------------------------------------------------------===// 385 386 TypeAttr TypeAttr::get(Type value) { 387 return Base::get(value.getContext(), StandardAttributes::Type, value); 388 } 389 390 Type TypeAttr::getValue() const { return getImpl()->value; } 391 392 //===----------------------------------------------------------------------===// 393 // ElementsAttr 394 //===----------------------------------------------------------------------===// 395 396 ShapedType ElementsAttr::getType() const { 397 return Attribute::getType().cast<ShapedType>(); 398 } 399 400 /// Returns the number of elements held by this attribute. 401 int64_t ElementsAttr::getNumElements() const { 402 return getType().getNumElements(); 403 } 404 405 /// Return the value at the given index. If index does not refer to a valid 406 /// element, then a null attribute is returned. 407 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 408 switch (getKind()) { 409 case StandardAttributes::DenseElements: 410 return cast<DenseElementsAttr>().getValue(index); 411 case StandardAttributes::OpaqueElements: 412 return cast<OpaqueElementsAttr>().getValue(index); 413 case StandardAttributes::SparseElements: 414 return cast<SparseElementsAttr>().getValue(index); 415 default: 416 llvm_unreachable("unknown ElementsAttr kind"); 417 } 418 } 419 420 /// Return if the given 'index' refers to a valid element in this attribute. 421 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 422 auto type = getType(); 423 424 // Verify that the rank of the indices matches the held type. 425 auto rank = type.getRank(); 426 if (rank != static_cast<int64_t>(index.size())) 427 return false; 428 429 // Verify that all of the indices are within the shape dimensions. 430 auto shape = type.getShape(); 431 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 432 return static_cast<int64_t>(index[i]) < shape[i]; 433 }); 434 } 435 436 ElementsAttr 437 ElementsAttr::mapValues(Type newElementType, 438 function_ref<APInt(const APInt &)> mapping) const { 439 switch (getKind()) { 440 case StandardAttributes::DenseElements: 441 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 442 default: 443 llvm_unreachable("unsupported ElementsAttr subtype"); 444 } 445 } 446 447 ElementsAttr 448 ElementsAttr::mapValues(Type newElementType, 449 function_ref<APInt(const APFloat &)> mapping) const { 450 switch (getKind()) { 451 case StandardAttributes::DenseElements: 452 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 453 default: 454 llvm_unreachable("unsupported ElementsAttr subtype"); 455 } 456 } 457 458 /// Returns the 1 dimensional flattened row-major index from the given 459 /// multi-dimensional index. 460 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 461 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 462 auto type = getType(); 463 464 // Reduce the provided multidimensional index into a flattended 1D row-major 465 // index. 466 auto rank = type.getRank(); 467 auto shape = type.getShape(); 468 uint64_t valueIndex = 0; 469 uint64_t dimMultiplier = 1; 470 for (int i = rank - 1; i >= 0; --i) { 471 valueIndex += index[i] * dimMultiplier; 472 dimMultiplier *= shape[i]; 473 } 474 return valueIndex; 475 } 476 477 //===----------------------------------------------------------------------===// 478 // DenseElementAttr Utilities 479 //===----------------------------------------------------------------------===// 480 481 static size_t getDenseElementBitwidth(Type eltType) { 482 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 483 // with double semantics. 484 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 485 } 486 487 /// Get the bitwidth of a dense element type within the buffer. 488 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 489 static size_t getDenseElementStorageWidth(size_t origWidth) { 490 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 491 } 492 493 /// Set a bit to a specific value. 494 static void setBit(char *rawData, size_t bitPos, bool value) { 495 if (value) 496 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 497 else 498 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 499 } 500 501 /// Return the value of the specified bit. 502 static bool getBit(const char *rawData, size_t bitPos) { 503 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 504 } 505 506 /// Writes value to the bit position `bitPos` in array `rawData`. 507 static void writeBits(char *rawData, size_t bitPos, APInt value) { 508 size_t bitWidth = value.getBitWidth(); 509 510 // If the bitwidth is 1 we just toggle the specific bit. 511 if (bitWidth == 1) 512 return setBit(rawData, bitPos, value.isOneValue()); 513 514 // Otherwise, the bit position is guaranteed to be byte aligned. 515 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 516 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 517 llvm::divideCeil(bitWidth, CHAR_BIT), 518 rawData + (bitPos / CHAR_BIT)); 519 } 520 521 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 522 /// `rawData`. 523 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 524 // Handle a boolean bit position. 525 if (bitWidth == 1) 526 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 527 528 // Otherwise, the bit position must be 8-bit aligned. 529 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 530 APInt result(bitWidth, 0); 531 std::copy_n( 532 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 533 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 534 return result; 535 } 536 537 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 538 /// same element count as 'type'. 539 template <typename Values> 540 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 541 return (values.size() == 1) || 542 (type.getNumElements() == static_cast<int64_t>(values.size())); 543 } 544 545 //===----------------------------------------------------------------------===// 546 // DenseElementAttr Iterators 547 //===----------------------------------------------------------------------===// 548 549 /// Constructs a new iterator. 550 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 551 DenseElementsAttr attr, size_t index) 552 : indexed_accessor_iterator<AttributeElementIterator, const void *, 553 Attribute, Attribute, Attribute>( 554 attr.getAsOpaquePointer(), index) {} 555 556 /// Accesses the Attribute value at this iterator position. 557 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 558 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 559 Type eltTy = owner.getType().getElementType(); 560 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 561 if (intEltTy.getWidth() == 1) 562 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 563 owner.getContext()); 564 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 565 } 566 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 567 IntElementIterator intIt(owner, index); 568 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 569 return FloatAttr::get(eltTy, *floatIt); 570 } 571 llvm_unreachable("unexpected element type"); 572 } 573 574 /// Constructs a new iterator. 575 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 576 DenseElementsAttr attr, size_t dataIndex) 577 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 578 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 579 580 /// Accesses the bool value at this iterator position. 581 bool DenseElementsAttr::BoolElementIterator::operator*() const { 582 return getBit(getData(), getDataIndex()); 583 } 584 585 /// Constructs a new iterator. 586 DenseElementsAttr::IntElementIterator::IntElementIterator( 587 DenseElementsAttr attr, size_t dataIndex) 588 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 589 attr.getRawData().data(), attr.isSplat(), dataIndex), 590 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 591 592 /// Accesses the raw APInt value at this iterator position. 593 APInt DenseElementsAttr::IntElementIterator::operator*() const { 594 return readBits(getData(), 595 getDataIndex() * getDenseElementStorageWidth(bitWidth), 596 bitWidth); 597 } 598 599 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 600 const llvm::fltSemantics &smt, IntElementIterator it) 601 : llvm::mapped_iterator<IntElementIterator, 602 std::function<APFloat(const APInt &)>>( 603 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 604 605 //===----------------------------------------------------------------------===// 606 // DenseElementsAttr 607 //===----------------------------------------------------------------------===// 608 609 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 610 ArrayRef<Attribute> values) { 611 assert(type.getElementType().isSignlessIntOrFloat() && 612 "expected int or float element type"); 613 assert(hasSameElementsOrSplat(type, values)); 614 615 auto eltType = type.getElementType(); 616 size_t bitWidth = getDenseElementBitwidth(eltType); 617 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 618 619 // Compress the attribute values into a character buffer. 620 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 621 values.size()); 622 APInt intVal; 623 for (unsigned i = 0, e = values.size(); i < e; ++i) { 624 assert(eltType == values[i].getType() && 625 "expected attribute value to have element type"); 626 627 switch (eltType.getKind()) { 628 case StandardTypes::BF16: 629 case StandardTypes::F16: 630 case StandardTypes::F32: 631 case StandardTypes::F64: 632 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 633 break; 634 case StandardTypes::Integer: 635 intVal = values[i].isa<BoolAttr>() 636 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 637 : values[i].cast<IntegerAttr>().getValue(); 638 break; 639 default: 640 llvm_unreachable("unexpected element type"); 641 } 642 assert(intVal.getBitWidth() == bitWidth && 643 "expected value to have same bitwidth as element type"); 644 writeBits(data.data(), i * storageBitWidth, intVal); 645 } 646 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 647 } 648 649 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 650 ArrayRef<bool> values) { 651 assert(hasSameElementsOrSplat(type, values)); 652 assert(type.getElementType().isInteger(1)); 653 654 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 655 for (int i = 0, e = values.size(); i != e; ++i) 656 setBit(buff.data(), i, values[i]); 657 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 658 } 659 660 /// Constructs a dense integer elements attribute from an array of APInt 661 /// values. Each APInt value is expected to have the same bitwidth as the 662 /// element type of 'type'. 663 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 664 ArrayRef<APInt> values) { 665 assert(type.getElementType().isa<IntegerType>()); 666 return getRaw(type, values); 667 } 668 669 // Constructs a dense float elements attribute from an array of APFloat 670 // values. Each APFloat value is expected to have the same bitwidth as the 671 // element type of 'type'. 672 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 673 ArrayRef<APFloat> values) { 674 assert(type.getElementType().isa<FloatType>()); 675 676 // Convert the APFloat values to APInt and create a dense elements attribute. 677 std::vector<APInt> intValues(values.size()); 678 for (unsigned i = 0, e = values.size(); i != e; ++i) 679 intValues[i] = values[i].bitcastToAPInt(); 680 return getRaw(type, intValues); 681 } 682 683 /// Construct a dense elements attribute from a raw buffer representing the 684 /// data for this attribute. Users should generally not use this methods as 685 /// the expected buffer format may not be a form the user expects. 686 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, 687 ArrayRef<char> rawBuffer, 688 bool isSplatBuffer) { 689 return getRaw(type, rawBuffer, isSplatBuffer); 690 } 691 692 /// Constructs a dense elements attribute from an array of raw APInt values. 693 /// Each APInt value is expected to have the same bitwidth as the element type 694 /// of 'type'. 695 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 696 ArrayRef<APInt> values) { 697 assert(hasSameElementsOrSplat(type, values)); 698 699 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 700 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 701 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 702 values.size()); 703 for (unsigned i = 0, e = values.size(); i != e; ++i) { 704 assert(values[i].getBitWidth() == bitWidth); 705 writeBits(elementData.data(), i * storageBitWidth, values[i]); 706 } 707 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 708 } 709 710 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 711 ArrayRef<char> data, bool isSplat) { 712 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 713 "type must be ranked tensor or vector"); 714 assert(type.hasStaticShape() && "type must have static shape"); 715 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 716 data, isSplat); 717 } 718 719 /// Check the information for a C++ data type, check if this type is valid for 720 /// the current attribute. This method is used to verify specific type 721 /// invariants that the templatized 'getValues' method cannot. 722 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt, 723 bool isSigned) { 724 // Make sure that the data element size is the same as the type element width. 725 if (getDenseElementBitwidth(type.getElementType()) != 726 static_cast<size_t>(dataEltSize * CHAR_BIT)) 727 return false; 728 729 // Check that the element type is either float or integer. 730 if (!isInt) 731 return type.getElementType().isa<FloatType>(); 732 733 auto intType = type.getElementType().dyn_cast<IntegerType>(); 734 if (!intType) 735 return false; 736 737 // Make sure signedness semantics is consistent. 738 if (intType.isSignless()) 739 return true; 740 return intType.isSigned() ? isSigned : !isSigned; 741 } 742 743 /// Overload of the 'getRaw' method that asserts that the given type is of 744 /// integer type. This method is used to verify type invariants that the 745 /// templatized 'get' method cannot. 746 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 747 ArrayRef<char> data, 748 int64_t dataEltSize, 749 bool isInt, 750 bool isSigned) { 751 assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); 752 753 int64_t numElements = data.size() / dataEltSize; 754 assert(numElements == 1 || numElements == type.getNumElements()); 755 return getRaw(type, data, /*isSplat=*/numElements == 1); 756 } 757 758 /// A method used to verify specific type invariants that the templatized 'get' 759 /// method cannot. 760 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 761 bool isSigned) const { 762 return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned); 763 } 764 765 /// Returns if this attribute corresponds to a splat, i.e. if all element 766 /// values are the same. 767 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 768 769 /// Return the held element values as a range of Attributes. 770 auto DenseElementsAttr::getAttributeValues() const 771 -> llvm::iterator_range<AttributeElementIterator> { 772 return {attr_value_begin(), attr_value_end()}; 773 } 774 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 775 return AttributeElementIterator(*this, 0); 776 } 777 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 778 return AttributeElementIterator(*this, getNumElements()); 779 } 780 781 /// Return the held element values as a range of bool. The element type of 782 /// this attribute must be of integer type of bitwidth 1. 783 auto DenseElementsAttr::getBoolValues() const 784 -> llvm::iterator_range<BoolElementIterator> { 785 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 786 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 787 (void)eltType; 788 return {BoolElementIterator(*this, 0), 789 BoolElementIterator(*this, getNumElements())}; 790 } 791 792 /// Return the held element values as a range of APInts. The element type of 793 /// this attribute must be of integer type. 794 auto DenseElementsAttr::getIntValues() const 795 -> llvm::iterator_range<IntElementIterator> { 796 assert(getType().getElementType().isa<IntegerType>() && 797 "expected integer type"); 798 return {raw_int_begin(), raw_int_end()}; 799 } 800 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 801 assert(getType().getElementType().isa<IntegerType>() && 802 "expected integer type"); 803 return raw_int_begin(); 804 } 805 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 806 assert(getType().getElementType().isa<IntegerType>() && 807 "expected integer type"); 808 return raw_int_end(); 809 } 810 811 /// Return the held element values as a range of APFloat. The element type of 812 /// this attribute must be of float type. 813 auto DenseElementsAttr::getFloatValues() const 814 -> llvm::iterator_range<FloatElementIterator> { 815 auto elementType = getType().getElementType().cast<FloatType>(); 816 assert(elementType.isa<FloatType>() && "expected float type"); 817 const auto &elementSemantics = elementType.getFloatSemantics(); 818 return {FloatElementIterator(elementSemantics, raw_int_begin()), 819 FloatElementIterator(elementSemantics, raw_int_end())}; 820 } 821 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 822 return getFloatValues().begin(); 823 } 824 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 825 return getFloatValues().end(); 826 } 827 828 /// Return the raw storage data held by this attribute. 829 ArrayRef<char> DenseElementsAttr::getRawData() const { 830 return static_cast<ImplType *>(impl)->data; 831 } 832 833 /// Return a new DenseElementsAttr that has the same data as the current 834 /// attribute, but has been reshaped to 'newType'. The new type must have the 835 /// same total number of elements as well as element type. 836 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 837 ShapedType curType = getType(); 838 if (curType == newType) 839 return *this; 840 841 (void)curType; 842 assert(newType.getElementType() == curType.getElementType() && 843 "expected the same element type"); 844 assert(newType.getNumElements() == curType.getNumElements() && 845 "expected the same number of elements"); 846 return getRaw(newType, getRawData(), isSplat()); 847 } 848 849 DenseElementsAttr 850 DenseElementsAttr::mapValues(Type newElementType, 851 function_ref<APInt(const APInt &)> mapping) const { 852 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 853 } 854 855 DenseElementsAttr DenseElementsAttr::mapValues( 856 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 857 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 858 } 859 860 //===----------------------------------------------------------------------===// 861 // DenseFPElementsAttr 862 //===----------------------------------------------------------------------===// 863 864 template <typename Fn, typename Attr> 865 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 866 Type newElementType, 867 llvm::SmallVectorImpl<char> &data) { 868 size_t bitWidth = getDenseElementBitwidth(newElementType); 869 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 870 871 ShapedType newArrayType; 872 if (inType.isa<RankedTensorType>()) 873 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 874 else if (inType.isa<UnrankedTensorType>()) 875 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 876 else if (inType.isa<VectorType>()) 877 newArrayType = VectorType::get(inType.getShape(), newElementType); 878 else 879 assert(newArrayType && "Unhandled tensor type"); 880 881 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 882 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 883 884 // Functor used to process a single element value of the attribute. 885 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 886 auto newInt = mapping(value); 887 assert(newInt.getBitWidth() == bitWidth); 888 writeBits(data.data(), index * storageBitWidth, newInt); 889 }; 890 891 // Check for the splat case. 892 if (attr.isSplat()) { 893 processElt(*attr.begin(), /*index=*/0); 894 return newArrayType; 895 } 896 897 // Otherwise, process all of the element values. 898 uint64_t elementIdx = 0; 899 for (auto value : attr) 900 processElt(value, elementIdx++); 901 return newArrayType; 902 } 903 904 DenseElementsAttr DenseFPElementsAttr::mapValues( 905 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 906 llvm::SmallVector<char, 8> elementData; 907 auto newArrayType = 908 mappingHelper(mapping, *this, getType(), newElementType, elementData); 909 910 return getRaw(newArrayType, elementData, isSplat()); 911 } 912 913 /// Method for supporting type inquiry through isa, cast and dyn_cast. 914 bool DenseFPElementsAttr::classof(Attribute attr) { 915 return attr.isa<DenseElementsAttr>() && 916 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 917 } 918 919 //===----------------------------------------------------------------------===// 920 // DenseIntElementsAttr 921 //===----------------------------------------------------------------------===// 922 923 DenseElementsAttr DenseIntElementsAttr::mapValues( 924 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 925 llvm::SmallVector<char, 8> elementData; 926 auto newArrayType = 927 mappingHelper(mapping, *this, getType(), newElementType, elementData); 928 929 return getRaw(newArrayType, elementData, isSplat()); 930 } 931 932 /// Method for supporting type inquiry through isa, cast and dyn_cast. 933 bool DenseIntElementsAttr::classof(Attribute attr) { 934 return attr.isa<DenseElementsAttr>() && 935 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 936 } 937 938 //===----------------------------------------------------------------------===// 939 // OpaqueElementsAttr 940 //===----------------------------------------------------------------------===// 941 942 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 943 StringRef bytes) { 944 assert(TensorType::isValidElementType(type.getElementType()) && 945 "Input element type should be a valid tensor element type"); 946 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 947 dialect, bytes); 948 } 949 950 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 951 952 /// Return the value at the given index. If index does not refer to a valid 953 /// element, then a null attribute is returned. 954 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 955 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 956 if (Dialect *dialect = getDialect()) 957 return dialect->extractElementHook(*this, index); 958 return Attribute(); 959 } 960 961 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 962 963 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 964 if (auto *d = getDialect()) 965 return d->decodeHook(*this, result); 966 return true; 967 } 968 969 //===----------------------------------------------------------------------===// 970 // SparseElementsAttr 971 //===----------------------------------------------------------------------===// 972 973 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 974 DenseElementsAttr indices, 975 DenseElementsAttr values) { 976 assert(indices.getType().getElementType().isInteger(64) && 977 "expected sparse indices to be 64-bit integer values"); 978 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 979 "type must be ranked tensor or vector"); 980 assert(type.hasStaticShape() && "type must have static shape"); 981 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 982 indices.cast<DenseIntElementsAttr>(), values); 983 } 984 985 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 986 return getImpl()->indices; 987 } 988 989 DenseElementsAttr SparseElementsAttr::getValues() const { 990 return getImpl()->values; 991 } 992 993 /// Return the value of the element at the given index. 994 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 995 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 996 auto type = getType(); 997 998 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 999 // as a 1-D index array. 1000 auto sparseIndices = getIndices(); 1001 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1002 1003 // Check to see if the indices are a splat. 1004 if (sparseIndices.isSplat()) { 1005 // If the index is also not a splat of the index value, we know that the 1006 // value is zero. 1007 auto splatIndex = *sparseIndexValues.begin(); 1008 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 1009 return getZeroAttr(); 1010 1011 // If the indices are a splat, we also expect the values to be a splat. 1012 assert(getValues().isSplat() && "expected splat values"); 1013 return getValues().getSplatValue(); 1014 } 1015 1016 // Build a mapping between known indices and the offset of the stored element. 1017 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 1018 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1019 size_t rank = type.getRank(); 1020 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1021 mappedIndices.try_emplace( 1022 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 1023 1024 // Look for the provided index key within the mapped indices. If the provided 1025 // index is not found, then return a zero attribute. 1026 auto it = mappedIndices.find(index); 1027 if (it == mappedIndices.end()) 1028 return getZeroAttr(); 1029 1030 // Otherwise, return the held sparse value element. 1031 return getValues().getValue(it->second); 1032 } 1033 1034 /// Get a zero APFloat for the given sparse attribute. 1035 APFloat SparseElementsAttr::getZeroAPFloat() const { 1036 auto eltType = getType().getElementType().cast<FloatType>(); 1037 return APFloat(eltType.getFloatSemantics()); 1038 } 1039 1040 /// Get a zero APInt for the given sparse attribute. 1041 APInt SparseElementsAttr::getZeroAPInt() const { 1042 auto eltType = getType().getElementType().cast<IntegerType>(); 1043 return APInt::getNullValue(eltType.getWidth()); 1044 } 1045 1046 /// Get a zero attribute for the given attribute type. 1047 Attribute SparseElementsAttr::getZeroAttr() const { 1048 auto eltType = getType().getElementType(); 1049 1050 // Handle floating point elements. 1051 if (eltType.isa<FloatType>()) 1052 return FloatAttr::get(eltType, 0); 1053 1054 // Otherwise, this is an integer. 1055 auto intEltTy = eltType.cast<IntegerType>(); 1056 if (intEltTy.getWidth() == 1) 1057 return BoolAttr::get(false, eltType.getContext()); 1058 return IntegerAttr::get(eltType, 0); 1059 } 1060 1061 /// Flatten, and return, all of the sparse indices in this attribute in 1062 /// row-major order. 1063 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1064 std::vector<ptrdiff_t> flatSparseIndices; 1065 1066 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1067 // as a 1-D index array. 1068 auto sparseIndices = getIndices(); 1069 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1070 if (sparseIndices.isSplat()) { 1071 SmallVector<uint64_t, 8> indices(getType().getRank(), 1072 *sparseIndexValues.begin()); 1073 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1074 return flatSparseIndices; 1075 } 1076 1077 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1078 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1079 size_t rank = getType().getRank(); 1080 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1081 flatSparseIndices.push_back(getFlattenedIndex( 1082 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1083 return flatSparseIndices; 1084 } 1085 1086 //===----------------------------------------------------------------------===// 1087 // NamedAttributeList 1088 //===----------------------------------------------------------------------===// 1089 1090 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1091 setAttrs(attributes); 1092 } 1093 1094 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1095 return attrs ? attrs.getValue() : llvm::None; 1096 } 1097 1098 /// Replace the held attributes with ones provided in 'newAttrs'. 1099 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1100 // Don't create an attribute list if there are no attributes. 1101 if (attributes.empty()) 1102 attrs = nullptr; 1103 else 1104 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1105 } 1106 1107 /// Return the specified attribute if present, null otherwise. 1108 Attribute NamedAttributeList::get(StringRef name) const { 1109 return attrs ? attrs.get(name) : nullptr; 1110 } 1111 1112 /// Return the specified attribute if present, null otherwise. 1113 Attribute NamedAttributeList::get(Identifier name) const { 1114 return attrs ? attrs.get(name) : nullptr; 1115 } 1116 1117 /// If the an attribute exists with the specified name, change it to the new 1118 /// value. Otherwise, add a new attribute with the specified name/value. 1119 void NamedAttributeList::set(Identifier name, Attribute value) { 1120 assert(value && "attributes may never be null"); 1121 1122 // If we already have this attribute, replace it. 1123 auto origAttrs = getAttrs(); 1124 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1125 for (auto &elt : newAttrs) 1126 if (elt.first == name) { 1127 elt.second = value; 1128 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1129 return; 1130 } 1131 1132 // Otherwise, add it. 1133 newAttrs.push_back({name, value}); 1134 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1135 } 1136 1137 /// Remove the attribute with the specified name if it exists. The return 1138 /// value indicates whether the attribute was present or not. 1139 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1140 auto origAttrs = getAttrs(); 1141 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1142 if (origAttrs[i].first == name) { 1143 // Handle the simple case of removing the only attribute in the list. 1144 if (e == 1) { 1145 attrs = nullptr; 1146 return RemoveResult::Removed; 1147 } 1148 1149 SmallVector<NamedAttribute, 8> newAttrs; 1150 newAttrs.reserve(origAttrs.size() - 1); 1151 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1152 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1153 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1154 return RemoveResult::Removed; 1155 } 1156 } 1157 return RemoveResult::NotFound; 1158 } 1159