1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 //===----------------------------------------------------------------------===// 72 // BoolAttr 73 //===----------------------------------------------------------------------===// 74 75 bool BoolAttr::getValue() const { return getImpl()->value; } 76 77 //===----------------------------------------------------------------------===// 78 // DictionaryAttr 79 //===----------------------------------------------------------------------===// 80 81 /// Perform a three-way comparison between the names of the specified 82 /// NamedAttributes. 83 static int compareNamedAttributes(const NamedAttribute *lhs, 84 const NamedAttribute *rhs) { 85 return lhs->first.strref().compare(rhs->first.strref()); 86 } 87 88 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 89 MLIRContext *context) { 90 assert(llvm::all_of(value, 91 [](const NamedAttribute &attr) { return attr.second; }) && 92 "value cannot have null entries"); 93 94 // We need to sort the element list to canonicalize it, but we also don't want 95 // to do a ton of work in the super common case where the element list is 96 // already sorted. 97 SmallVector<NamedAttribute, 8> storage; 98 switch (value.size()) { 99 case 0: 100 break; 101 case 1: 102 // A single element is already sorted. 103 break; 104 case 2: 105 assert(value[0].first != value[1].first && 106 "DictionaryAttr element names must be unique"); 107 108 // Don't invoke a general sort for two element case. 109 if (value[0].first.strref() > value[1].first.strref()) { 110 storage.push_back(value[1]); 111 storage.push_back(value[0]); 112 value = storage; 113 } 114 break; 115 default: 116 // Check to see they are sorted already. 117 bool isSorted = true; 118 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 119 if (value[i].first.strref() > value[i + 1].first.strref()) { 120 isSorted = false; 121 break; 122 } 123 } 124 // If not, do a general sort. 125 if (!isSorted) { 126 storage.append(value.begin(), value.end()); 127 llvm::array_pod_sort(storage.begin(), storage.end(), 128 compareNamedAttributes); 129 value = storage; 130 } 131 132 // Ensure that the attribute elements are unique. 133 assert(std::adjacent_find(value.begin(), value.end(), 134 [](NamedAttribute l, NamedAttribute r) { 135 return l.first == r.first; 136 }) == value.end() && 137 "DictionaryAttr element names must be unique"); 138 } 139 140 return Base::get(context, StandardAttributes::Dictionary, value); 141 } 142 143 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 144 return getImpl()->getElements(); 145 } 146 147 /// Return the specified attribute if present, null otherwise. 148 Attribute DictionaryAttr::get(StringRef name) const { 149 ArrayRef<NamedAttribute> values = getValue(); 150 auto compare = [](NamedAttribute attr, StringRef name) { 151 return attr.first.strref() < name; 152 }; 153 auto it = llvm::lower_bound(values, name, compare); 154 return it != values.end() && it->first.is(name) ? it->second : Attribute(); 155 } 156 Attribute DictionaryAttr::get(Identifier name) const { 157 for (auto elt : getValue()) 158 if (elt.first == name) 159 return elt.second; 160 return nullptr; 161 } 162 163 DictionaryAttr::iterator DictionaryAttr::begin() const { 164 return getValue().begin(); 165 } 166 DictionaryAttr::iterator DictionaryAttr::end() const { 167 return getValue().end(); 168 } 169 size_t DictionaryAttr::size() const { return getValue().size(); } 170 171 //===----------------------------------------------------------------------===// 172 // FloatAttr 173 //===----------------------------------------------------------------------===// 174 175 FloatAttr FloatAttr::get(Type type, double value) { 176 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 177 } 178 179 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 180 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 181 type, value); 182 } 183 184 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 185 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 186 } 187 188 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 189 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 190 type, value); 191 } 192 193 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 194 195 double FloatAttr::getValueAsDouble() const { 196 return getValueAsDouble(getValue()); 197 } 198 double FloatAttr::getValueAsDouble(APFloat value) { 199 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 200 bool losesInfo = false; 201 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 202 &losesInfo); 203 } 204 return value.convertToDouble(); 205 } 206 207 /// Verify construction invariants. 208 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, 209 Type type) { 210 if (!type.isa<FloatType>()) 211 return emitOptionalError(loc, "expected floating point type"); 212 return success(); 213 } 214 215 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 216 MLIRContext *ctx, 217 Type type, double value) { 218 return verifyFloatTypeInvariants(loc, type); 219 } 220 221 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 222 MLIRContext *ctx, 223 Type type, 224 const APFloat &value) { 225 // Verify that the type is correct. 226 if (failed(verifyFloatTypeInvariants(loc, type))) 227 return failure(); 228 229 // Verify that the type semantics match that of the value. 230 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 231 return emitOptionalError( 232 loc, "FloatAttr type doesn't match the type implied by its value"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SymbolRefAttr 239 //===----------------------------------------------------------------------===// 240 241 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 242 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 243 .cast<FlatSymbolRefAttr>(); 244 } 245 246 SymbolRefAttr SymbolRefAttr::get(StringRef value, 247 ArrayRef<FlatSymbolRefAttr> nestedReferences, 248 MLIRContext *ctx) { 249 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 250 } 251 252 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 253 254 StringRef SymbolRefAttr::getLeafReference() const { 255 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 256 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 257 } 258 259 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 260 return getImpl()->getNestedRefs(); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // IntegerAttr 265 //===----------------------------------------------------------------------===// 266 267 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 268 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 269 } 270 271 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 272 // This uses 64 bit APInts by default for index type. 273 if (type.isIndex()) 274 return get(type, APInt(64, value)); 275 276 auto intType = type.cast<IntegerType>(); 277 return get(type, APInt(intType.getWidth(), value)); 278 } 279 280 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 281 282 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } 283 284 //===----------------------------------------------------------------------===// 285 // IntegerSetAttr 286 //===----------------------------------------------------------------------===// 287 288 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 289 return Base::get(value.getConstraint(0).getContext(), 290 StandardAttributes::IntegerSet, value); 291 } 292 293 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 294 295 //===----------------------------------------------------------------------===// 296 // OpaqueAttr 297 //===----------------------------------------------------------------------===// 298 299 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 300 MLIRContext *context) { 301 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 302 type); 303 } 304 305 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 306 Type type, Location location) { 307 return Base::getChecked(location, type.getContext(), 308 StandardAttributes::Opaque, dialect, attrData, type); 309 } 310 311 /// Returns the dialect namespace of the opaque attribute. 312 Identifier OpaqueAttr::getDialectNamespace() const { 313 return getImpl()->dialectNamespace; 314 } 315 316 /// Returns the raw attribute data of the opaque attribute. 317 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 318 319 /// Verify the construction of an opaque attribute. 320 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, 321 MLIRContext *context, 322 Identifier dialect, 323 StringRef attrData, 324 Type type) { 325 if (!Dialect::isValidNamespace(dialect.strref())) 326 return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); 327 return success(); 328 } 329 330 //===----------------------------------------------------------------------===// 331 // StringAttr 332 //===----------------------------------------------------------------------===// 333 334 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 335 return get(bytes, NoneType::get(context)); 336 } 337 338 /// Get an instance of a StringAttr with the given string and Type. 339 StringAttr StringAttr::get(StringRef bytes, Type type) { 340 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 341 } 342 343 StringRef StringAttr::getValue() const { return getImpl()->value; } 344 345 //===----------------------------------------------------------------------===// 346 // TypeAttr 347 //===----------------------------------------------------------------------===// 348 349 TypeAttr TypeAttr::get(Type value) { 350 return Base::get(value.getContext(), StandardAttributes::Type, value); 351 } 352 353 Type TypeAttr::getValue() const { return getImpl()->value; } 354 355 //===----------------------------------------------------------------------===// 356 // ElementsAttr 357 //===----------------------------------------------------------------------===// 358 359 ShapedType ElementsAttr::getType() const { 360 return Attribute::getType().cast<ShapedType>(); 361 } 362 363 /// Returns the number of elements held by this attribute. 364 int64_t ElementsAttr::getNumElements() const { 365 return getType().getNumElements(); 366 } 367 368 /// Return the value at the given index. If index does not refer to a valid 369 /// element, then a null attribute is returned. 370 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 371 switch (getKind()) { 372 case StandardAttributes::DenseElements: 373 return cast<DenseElementsAttr>().getValue(index); 374 case StandardAttributes::OpaqueElements: 375 return cast<OpaqueElementsAttr>().getValue(index); 376 case StandardAttributes::SparseElements: 377 return cast<SparseElementsAttr>().getValue(index); 378 default: 379 llvm_unreachable("unknown ElementsAttr kind"); 380 } 381 } 382 383 /// Return if the given 'index' refers to a valid element in this attribute. 384 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 385 auto type = getType(); 386 387 // Verify that the rank of the indices matches the held type. 388 auto rank = type.getRank(); 389 if (rank != static_cast<int64_t>(index.size())) 390 return false; 391 392 // Verify that all of the indices are within the shape dimensions. 393 auto shape = type.getShape(); 394 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 395 return static_cast<int64_t>(index[i]) < shape[i]; 396 }); 397 } 398 399 ElementsAttr 400 ElementsAttr::mapValues(Type newElementType, 401 function_ref<APInt(const APInt &)> mapping) const { 402 switch (getKind()) { 403 case StandardAttributes::DenseElements: 404 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 405 default: 406 llvm_unreachable("unsupported ElementsAttr subtype"); 407 } 408 } 409 410 ElementsAttr 411 ElementsAttr::mapValues(Type newElementType, 412 function_ref<APInt(const APFloat &)> mapping) const { 413 switch (getKind()) { 414 case StandardAttributes::DenseElements: 415 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 416 default: 417 llvm_unreachable("unsupported ElementsAttr subtype"); 418 } 419 } 420 421 /// Returns the 1 dimensional flattened row-major index from the given 422 /// multi-dimensional index. 423 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 424 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 425 auto type = getType(); 426 427 // Reduce the provided multidimensional index into a flattended 1D row-major 428 // index. 429 auto rank = type.getRank(); 430 auto shape = type.getShape(); 431 uint64_t valueIndex = 0; 432 uint64_t dimMultiplier = 1; 433 for (int i = rank - 1; i >= 0; --i) { 434 valueIndex += index[i] * dimMultiplier; 435 dimMultiplier *= shape[i]; 436 } 437 return valueIndex; 438 } 439 440 //===----------------------------------------------------------------------===// 441 // DenseElementAttr Utilities 442 //===----------------------------------------------------------------------===// 443 444 static size_t getDenseElementBitwidth(Type eltType) { 445 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 446 // with double semantics. 447 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 448 } 449 450 /// Get the bitwidth of a dense element type within the buffer. 451 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 452 static size_t getDenseElementStorageWidth(size_t origWidth) { 453 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 454 } 455 456 /// Set a bit to a specific value. 457 static void setBit(char *rawData, size_t bitPos, bool value) { 458 if (value) 459 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 460 else 461 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 462 } 463 464 /// Return the value of the specified bit. 465 static bool getBit(const char *rawData, size_t bitPos) { 466 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 467 } 468 469 /// Writes value to the bit position `bitPos` in array `rawData`. 470 static void writeBits(char *rawData, size_t bitPos, APInt value) { 471 size_t bitWidth = value.getBitWidth(); 472 473 // If the bitwidth is 1 we just toggle the specific bit. 474 if (bitWidth == 1) 475 return setBit(rawData, bitPos, value.isOneValue()); 476 477 // Otherwise, the bit position is guaranteed to be byte aligned. 478 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 479 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 480 llvm::divideCeil(bitWidth, CHAR_BIT), 481 rawData + (bitPos / CHAR_BIT)); 482 } 483 484 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 485 /// `rawData`. 486 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 487 // Handle a boolean bit position. 488 if (bitWidth == 1) 489 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 490 491 // Otherwise, the bit position must be 8-bit aligned. 492 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 493 APInt result(bitWidth, 0); 494 std::copy_n( 495 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 496 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 497 return result; 498 } 499 500 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 501 /// same element count as 'type'. 502 template <typename Values> 503 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 504 return (values.size() == 1) || 505 (type.getNumElements() == static_cast<int64_t>(values.size())); 506 } 507 508 //===----------------------------------------------------------------------===// 509 // DenseElementAttr Iterators 510 //===----------------------------------------------------------------------===// 511 512 /// Constructs a new iterator. 513 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 514 DenseElementsAttr attr, size_t index) 515 : indexed_accessor_iterator<AttributeElementIterator, const void *, 516 Attribute, Attribute, Attribute>( 517 attr.getAsOpaquePointer(), index) {} 518 519 /// Accesses the Attribute value at this iterator position. 520 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 521 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 522 Type eltTy = owner.getType().getElementType(); 523 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 524 if (intEltTy.getWidth() == 1) 525 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 526 owner.getContext()); 527 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 528 } 529 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 530 IntElementIterator intIt(owner, index); 531 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 532 return FloatAttr::get(eltTy, *floatIt); 533 } 534 llvm_unreachable("unexpected element type"); 535 } 536 537 /// Constructs a new iterator. 538 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 539 DenseElementsAttr attr, size_t dataIndex) 540 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 541 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 542 543 /// Accesses the bool value at this iterator position. 544 bool DenseElementsAttr::BoolElementIterator::operator*() const { 545 return getBit(getData(), getDataIndex()); 546 } 547 548 /// Constructs a new iterator. 549 DenseElementsAttr::IntElementIterator::IntElementIterator( 550 DenseElementsAttr attr, size_t dataIndex) 551 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 552 attr.getRawData().data(), attr.isSplat(), dataIndex), 553 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 554 555 /// Accesses the raw APInt value at this iterator position. 556 APInt DenseElementsAttr::IntElementIterator::operator*() const { 557 return readBits(getData(), 558 getDataIndex() * getDenseElementStorageWidth(bitWidth), 559 bitWidth); 560 } 561 562 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 563 const llvm::fltSemantics &smt, IntElementIterator it) 564 : llvm::mapped_iterator<IntElementIterator, 565 std::function<APFloat(const APInt &)>>( 566 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 567 568 //===----------------------------------------------------------------------===// 569 // DenseElementsAttr 570 //===----------------------------------------------------------------------===// 571 572 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 573 ArrayRef<Attribute> values) { 574 assert(type.getElementType().isIntOrFloat() && 575 "expected int or float element type"); 576 assert(hasSameElementsOrSplat(type, values)); 577 578 auto eltType = type.getElementType(); 579 size_t bitWidth = getDenseElementBitwidth(eltType); 580 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 581 582 // Compress the attribute values into a character buffer. 583 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 584 values.size()); 585 APInt intVal; 586 for (unsigned i = 0, e = values.size(); i < e; ++i) { 587 assert(eltType == values[i].getType() && 588 "expected attribute value to have element type"); 589 590 switch (eltType.getKind()) { 591 case StandardTypes::BF16: 592 case StandardTypes::F16: 593 case StandardTypes::F32: 594 case StandardTypes::F64: 595 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 596 break; 597 case StandardTypes::Integer: 598 intVal = values[i].isa<BoolAttr>() 599 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 600 : values[i].cast<IntegerAttr>().getValue(); 601 break; 602 default: 603 llvm_unreachable("unexpected element type"); 604 } 605 assert(intVal.getBitWidth() == bitWidth && 606 "expected value to have same bitwidth as element type"); 607 writeBits(data.data(), i * storageBitWidth, intVal); 608 } 609 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 610 } 611 612 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 613 ArrayRef<bool> values) { 614 assert(hasSameElementsOrSplat(type, values)); 615 assert(type.getElementType().isInteger(1)); 616 617 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 618 for (int i = 0, e = values.size(); i != e; ++i) 619 setBit(buff.data(), i, values[i]); 620 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 621 } 622 623 /// Constructs a dense integer elements attribute from an array of APInt 624 /// values. Each APInt value is expected to have the same bitwidth as the 625 /// element type of 'type'. 626 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 627 ArrayRef<APInt> values) { 628 assert(type.getElementType().isa<IntegerType>()); 629 return getRaw(type, values); 630 } 631 632 // Constructs a dense float elements attribute from an array of APFloat 633 // values. Each APFloat value is expected to have the same bitwidth as the 634 // element type of 'type'. 635 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 636 ArrayRef<APFloat> values) { 637 assert(type.getElementType().isa<FloatType>()); 638 639 // Convert the APFloat values to APInt and create a dense elements attribute. 640 std::vector<APInt> intValues(values.size()); 641 for (unsigned i = 0, e = values.size(); i != e; ++i) 642 intValues[i] = values[i].bitcastToAPInt(); 643 return getRaw(type, intValues); 644 } 645 646 // Constructs a dense elements attribute from an array of raw APInt values. 647 // Each APInt value is expected to have the same bitwidth as the element type 648 // of 'type'. 649 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 650 ArrayRef<APInt> values) { 651 assert(hasSameElementsOrSplat(type, values)); 652 653 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 654 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 655 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 656 values.size()); 657 for (unsigned i = 0, e = values.size(); i != e; ++i) { 658 assert(values[i].getBitWidth() == bitWidth); 659 writeBits(elementData.data(), i * storageBitWidth, values[i]); 660 } 661 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 662 } 663 664 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 665 ArrayRef<char> data, bool isSplat) { 666 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 667 "type must be ranked tensor or vector"); 668 assert(type.hasStaticShape() && "type must have static shape"); 669 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 670 data, isSplat); 671 } 672 673 /// Check the information for a c++ data type, check if this type is valid for 674 /// the current attribute. This method is used to verify specific type 675 /// invariants that the templatized 'getValues' method cannot. 676 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, 677 bool isInt) { 678 // Make sure that the data element size is the same as the type element width. 679 if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth()) 680 return false; 681 682 // Check that the element type is valid. 683 return isInt ? type.getElementType().isa<IntegerType>() 684 : type.getElementType().isa<FloatType>(); 685 } 686 687 /// Overload of the 'getRaw' method that asserts that the given type is of 688 /// integer type. This method is used to verify type invariants that the 689 /// templatized 'get' method cannot. 690 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 691 ArrayRef<char> data, 692 int64_t dataEltSize, 693 bool isInt) { 694 assert(::isValidIntOrFloat(type, dataEltSize, isInt)); 695 696 int64_t numElements = data.size() / dataEltSize; 697 assert(numElements == 1 || numElements == type.getNumElements()); 698 return getRaw(type, data, /*isSplat=*/numElements == 1); 699 } 700 701 /// A method used to verify specific type invariants that the templatized 'get' 702 /// method cannot. 703 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, 704 bool isInt) const { 705 return ::isValidIntOrFloat(getType(), dataEltSize, isInt); 706 } 707 708 /// Return the raw storage data held by this attribute. 709 ArrayRef<char> DenseElementsAttr::getRawData() const { 710 return static_cast<ImplType *>(impl)->data; 711 } 712 713 /// Returns if this attribute corresponds to a splat, i.e. if all element 714 /// values are the same. 715 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 716 717 /// Return the held element values as a range of Attributes. 718 auto DenseElementsAttr::getAttributeValues() const 719 -> llvm::iterator_range<AttributeElementIterator> { 720 return {attr_value_begin(), attr_value_end()}; 721 } 722 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 723 return AttributeElementIterator(*this, 0); 724 } 725 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 726 return AttributeElementIterator(*this, getNumElements()); 727 } 728 729 /// Return the held element values as a range of bool. The element type of 730 /// this attribute must be of integer type of bitwidth 1. 731 auto DenseElementsAttr::getBoolValues() const 732 -> llvm::iterator_range<BoolElementIterator> { 733 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 734 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 735 (void)eltType; 736 return {BoolElementIterator(*this, 0), 737 BoolElementIterator(*this, getNumElements())}; 738 } 739 740 /// Return the held element values as a range of APInts. The element type of 741 /// this attribute must be of integer type. 742 auto DenseElementsAttr::getIntValues() const 743 -> llvm::iterator_range<IntElementIterator> { 744 assert(getType().getElementType().isa<IntegerType>() && 745 "expected integer type"); 746 return {raw_int_begin(), raw_int_end()}; 747 } 748 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 749 assert(getType().getElementType().isa<IntegerType>() && 750 "expected integer type"); 751 return raw_int_begin(); 752 } 753 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 754 assert(getType().getElementType().isa<IntegerType>() && 755 "expected integer type"); 756 return raw_int_end(); 757 } 758 759 /// Return the held element values as a range of APFloat. The element type of 760 /// this attribute must be of float type. 761 auto DenseElementsAttr::getFloatValues() const 762 -> llvm::iterator_range<FloatElementIterator> { 763 auto elementType = getType().getElementType().cast<FloatType>(); 764 assert(elementType.isa<FloatType>() && "expected float type"); 765 const auto &elementSemantics = elementType.getFloatSemantics(); 766 return {FloatElementIterator(elementSemantics, raw_int_begin()), 767 FloatElementIterator(elementSemantics, raw_int_end())}; 768 } 769 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 770 return getFloatValues().begin(); 771 } 772 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 773 return getFloatValues().end(); 774 } 775 776 /// Return a new DenseElementsAttr that has the same data as the current 777 /// attribute, but has been reshaped to 'newType'. The new type must have the 778 /// same total number of elements as well as element type. 779 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 780 ShapedType curType = getType(); 781 if (curType == newType) 782 return *this; 783 784 (void)curType; 785 assert(newType.getElementType() == curType.getElementType() && 786 "expected the same element type"); 787 assert(newType.getNumElements() == curType.getNumElements() && 788 "expected the same number of elements"); 789 return getRaw(newType, getRawData(), isSplat()); 790 } 791 792 DenseElementsAttr 793 DenseElementsAttr::mapValues(Type newElementType, 794 function_ref<APInt(const APInt &)> mapping) const { 795 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 796 } 797 798 DenseElementsAttr DenseElementsAttr::mapValues( 799 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 800 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 801 } 802 803 //===----------------------------------------------------------------------===// 804 // DenseFPElementsAttr 805 //===----------------------------------------------------------------------===// 806 807 template <typename Fn, typename Attr> 808 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 809 Type newElementType, 810 llvm::SmallVectorImpl<char> &data) { 811 size_t bitWidth = getDenseElementBitwidth(newElementType); 812 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 813 814 ShapedType newArrayType; 815 if (inType.isa<RankedTensorType>()) 816 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 817 else if (inType.isa<UnrankedTensorType>()) 818 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 819 else if (inType.isa<VectorType>()) 820 newArrayType = VectorType::get(inType.getShape(), newElementType); 821 else 822 assert(newArrayType && "Unhandled tensor type"); 823 824 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 825 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 826 827 // Functor used to process a single element value of the attribute. 828 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 829 auto newInt = mapping(value); 830 assert(newInt.getBitWidth() == bitWidth); 831 writeBits(data.data(), index * storageBitWidth, newInt); 832 }; 833 834 // Check for the splat case. 835 if (attr.isSplat()) { 836 processElt(*attr.begin(), /*index=*/0); 837 return newArrayType; 838 } 839 840 // Otherwise, process all of the element values. 841 uint64_t elementIdx = 0; 842 for (auto value : attr) 843 processElt(value, elementIdx++); 844 return newArrayType; 845 } 846 847 DenseElementsAttr DenseFPElementsAttr::mapValues( 848 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 849 llvm::SmallVector<char, 8> elementData; 850 auto newArrayType = 851 mappingHelper(mapping, *this, getType(), newElementType, elementData); 852 853 return getRaw(newArrayType, elementData, isSplat()); 854 } 855 856 /// Method for supporting type inquiry through isa, cast and dyn_cast. 857 bool DenseFPElementsAttr::classof(Attribute attr) { 858 return attr.isa<DenseElementsAttr>() && 859 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 860 } 861 862 //===----------------------------------------------------------------------===// 863 // DenseIntElementsAttr 864 //===----------------------------------------------------------------------===// 865 866 DenseElementsAttr DenseIntElementsAttr::mapValues( 867 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 868 llvm::SmallVector<char, 8> elementData; 869 auto newArrayType = 870 mappingHelper(mapping, *this, getType(), newElementType, elementData); 871 872 return getRaw(newArrayType, elementData, isSplat()); 873 } 874 875 /// Method for supporting type inquiry through isa, cast and dyn_cast. 876 bool DenseIntElementsAttr::classof(Attribute attr) { 877 return attr.isa<DenseElementsAttr>() && 878 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 879 } 880 881 //===----------------------------------------------------------------------===// 882 // OpaqueElementsAttr 883 //===----------------------------------------------------------------------===// 884 885 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 886 StringRef bytes) { 887 assert(TensorType::isValidElementType(type.getElementType()) && 888 "Input element type should be a valid tensor element type"); 889 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 890 dialect, bytes); 891 } 892 893 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 894 895 /// Return the value at the given index. If index does not refer to a valid 896 /// element, then a null attribute is returned. 897 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 898 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 899 if (Dialect *dialect = getDialect()) 900 return dialect->extractElementHook(*this, index); 901 return Attribute(); 902 } 903 904 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 905 906 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 907 if (auto *d = getDialect()) 908 return d->decodeHook(*this, result); 909 return true; 910 } 911 912 //===----------------------------------------------------------------------===// 913 // SparseElementsAttr 914 //===----------------------------------------------------------------------===// 915 916 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 917 DenseElementsAttr indices, 918 DenseElementsAttr values) { 919 assert(indices.getType().getElementType().isInteger(64) && 920 "expected sparse indices to be 64-bit integer values"); 921 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 922 "type must be ranked tensor or vector"); 923 assert(type.hasStaticShape() && "type must have static shape"); 924 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 925 indices.cast<DenseIntElementsAttr>(), values); 926 } 927 928 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 929 return getImpl()->indices; 930 } 931 932 DenseElementsAttr SparseElementsAttr::getValues() const { 933 return getImpl()->values; 934 } 935 936 /// Return the value of the element at the given index. 937 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 938 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 939 auto type = getType(); 940 941 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 942 // as a 1-D index array. 943 auto sparseIndices = getIndices(); 944 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 945 946 // Check to see if the indices are a splat. 947 if (sparseIndices.isSplat()) { 948 // If the index is also not a splat of the index value, we know that the 949 // value is zero. 950 auto splatIndex = *sparseIndexValues.begin(); 951 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 952 return getZeroAttr(); 953 954 // If the indices are a splat, we also expect the values to be a splat. 955 assert(getValues().isSplat() && "expected splat values"); 956 return getValues().getSplatValue(); 957 } 958 959 // Build a mapping between known indices and the offset of the stored element. 960 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 961 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 962 size_t rank = type.getRank(); 963 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 964 mappedIndices.try_emplace( 965 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 966 967 // Look for the provided index key within the mapped indices. If the provided 968 // index is not found, then return a zero attribute. 969 auto it = mappedIndices.find(index); 970 if (it == mappedIndices.end()) 971 return getZeroAttr(); 972 973 // Otherwise, return the held sparse value element. 974 return getValues().getValue(it->second); 975 } 976 977 /// Get a zero APFloat for the given sparse attribute. 978 APFloat SparseElementsAttr::getZeroAPFloat() const { 979 auto eltType = getType().getElementType().cast<FloatType>(); 980 return APFloat(eltType.getFloatSemantics()); 981 } 982 983 /// Get a zero APInt for the given sparse attribute. 984 APInt SparseElementsAttr::getZeroAPInt() const { 985 auto eltType = getType().getElementType().cast<IntegerType>(); 986 return APInt::getNullValue(eltType.getWidth()); 987 } 988 989 /// Get a zero attribute for the given attribute type. 990 Attribute SparseElementsAttr::getZeroAttr() const { 991 auto eltType = getType().getElementType(); 992 993 // Handle floating point elements. 994 if (eltType.isa<FloatType>()) 995 return FloatAttr::get(eltType, 0); 996 997 // Otherwise, this is an integer. 998 auto intEltTy = eltType.cast<IntegerType>(); 999 if (intEltTy.getWidth() == 1) 1000 return BoolAttr::get(false, eltType.getContext()); 1001 return IntegerAttr::get(eltType, 0); 1002 } 1003 1004 /// Flatten, and return, all of the sparse indices in this attribute in 1005 /// row-major order. 1006 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1007 std::vector<ptrdiff_t> flatSparseIndices; 1008 1009 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1010 // as a 1-D index array. 1011 auto sparseIndices = getIndices(); 1012 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1013 if (sparseIndices.isSplat()) { 1014 SmallVector<uint64_t, 8> indices(getType().getRank(), 1015 *sparseIndexValues.begin()); 1016 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1017 return flatSparseIndices; 1018 } 1019 1020 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1021 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1022 size_t rank = getType().getRank(); 1023 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1024 flatSparseIndices.push_back(getFlattenedIndex( 1025 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1026 return flatSparseIndices; 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // NamedAttributeList 1031 //===----------------------------------------------------------------------===// 1032 1033 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1034 setAttrs(attributes); 1035 } 1036 1037 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1038 return attrs ? attrs.getValue() : llvm::None; 1039 } 1040 1041 /// Replace the held attributes with ones provided in 'newAttrs'. 1042 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1043 // Don't create an attribute list if there are no attributes. 1044 if (attributes.empty()) 1045 attrs = nullptr; 1046 else 1047 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1048 } 1049 1050 /// Return the specified attribute if present, null otherwise. 1051 Attribute NamedAttributeList::get(StringRef name) const { 1052 return attrs ? attrs.get(name) : nullptr; 1053 } 1054 1055 /// Return the specified attribute if present, null otherwise. 1056 Attribute NamedAttributeList::get(Identifier name) const { 1057 return attrs ? attrs.get(name) : nullptr; 1058 } 1059 1060 /// If the an attribute exists with the specified name, change it to the new 1061 /// value. Otherwise, add a new attribute with the specified name/value. 1062 void NamedAttributeList::set(Identifier name, Attribute value) { 1063 assert(value && "attributes may never be null"); 1064 1065 // If we already have this attribute, replace it. 1066 auto origAttrs = getAttrs(); 1067 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1068 for (auto &elt : newAttrs) 1069 if (elt.first == name) { 1070 elt.second = value; 1071 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1072 return; 1073 } 1074 1075 // Otherwise, add it. 1076 newAttrs.push_back({name, value}); 1077 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1078 } 1079 1080 /// Remove the attribute with the specified name if it exists. The return 1081 /// value indicates whether the attribute was present or not. 1082 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1083 auto origAttrs = getAttrs(); 1084 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1085 if (origAttrs[i].first == name) { 1086 // Handle the simple case of removing the only attribute in the list. 1087 if (e == 1) { 1088 attrs = nullptr; 1089 return RemoveResult::Removed; 1090 } 1091 1092 SmallVector<NamedAttribute, 8> newAttrs; 1093 newAttrs.reserve(origAttrs.size() - 1); 1094 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1095 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1096 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1097 return RemoveResult::Removed; 1098 } 1099 } 1100 return RemoveResult::NotFound; 1101 } 1102