1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 //===----------------------------------------------------------------------===// 72 // BoolAttr 73 //===----------------------------------------------------------------------===// 74 75 bool BoolAttr::getValue() const { return getImpl()->value; } 76 77 //===----------------------------------------------------------------------===// 78 // DictionaryAttr 79 //===----------------------------------------------------------------------===// 80 81 /// Perform a three-way comparison between the names of the specified 82 /// NamedAttributes. 83 static int compareNamedAttributes(const NamedAttribute *lhs, 84 const NamedAttribute *rhs) { 85 return lhs->first.strref().compare(rhs->first.strref()); 86 } 87 88 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 89 MLIRContext *context) { 90 assert(llvm::all_of(value, 91 [](const NamedAttribute &attr) { return attr.second; }) && 92 "value cannot have null entries"); 93 94 // We need to sort the element list to canonicalize it, but we also don't want 95 // to do a ton of work in the super common case where the element list is 96 // already sorted. 97 SmallVector<NamedAttribute, 8> storage; 98 switch (value.size()) { 99 case 0: 100 break; 101 case 1: 102 // A single element is already sorted. 103 break; 104 case 2: 105 assert(value[0].first != value[1].first && 106 "DictionaryAttr element names must be unique"); 107 108 // Don't invoke a general sort for two element case. 109 if (value[0].first.strref() > value[1].first.strref()) { 110 storage.push_back(value[1]); 111 storage.push_back(value[0]); 112 value = storage; 113 } 114 break; 115 default: 116 // Check to see they are sorted already. 117 bool isSorted = true; 118 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 119 if (value[i].first.strref() > value[i + 1].first.strref()) { 120 isSorted = false; 121 break; 122 } 123 } 124 // If not, do a general sort. 125 if (!isSorted) { 126 storage.append(value.begin(), value.end()); 127 llvm::array_pod_sort(storage.begin(), storage.end(), 128 compareNamedAttributes); 129 value = storage; 130 } 131 132 // Ensure that the attribute elements are unique. 133 assert(std::adjacent_find(value.begin(), value.end(), 134 [](NamedAttribute l, NamedAttribute r) { 135 return l.first == r.first; 136 }) == value.end() && 137 "DictionaryAttr element names must be unique"); 138 } 139 140 return Base::get(context, StandardAttributes::Dictionary, value); 141 } 142 143 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 144 return getImpl()->getElements(); 145 } 146 147 /// Return the specified attribute if present, null otherwise. 148 Attribute DictionaryAttr::get(StringRef name) const { 149 ArrayRef<NamedAttribute> values = getValue(); 150 auto compare = [](NamedAttribute attr, StringRef name) { 151 return attr.first.strref() < name; 152 }; 153 auto it = llvm::lower_bound(values, name, compare); 154 return it != values.end() && it->first.is(name) ? it->second : Attribute(); 155 } 156 Attribute DictionaryAttr::get(Identifier name) const { 157 for (auto elt : getValue()) 158 if (elt.first == name) 159 return elt.second; 160 return nullptr; 161 } 162 163 DictionaryAttr::iterator DictionaryAttr::begin() const { 164 return getValue().begin(); 165 } 166 DictionaryAttr::iterator DictionaryAttr::end() const { 167 return getValue().end(); 168 } 169 size_t DictionaryAttr::size() const { return getValue().size(); } 170 171 //===----------------------------------------------------------------------===// 172 // FloatAttr 173 //===----------------------------------------------------------------------===// 174 175 FloatAttr FloatAttr::get(Type type, double value) { 176 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 177 } 178 179 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 180 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 181 type, value); 182 } 183 184 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 185 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 186 } 187 188 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 189 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 190 type, value); 191 } 192 193 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 194 195 double FloatAttr::getValueAsDouble() const { 196 return getValueAsDouble(getValue()); 197 } 198 double FloatAttr::getValueAsDouble(APFloat value) { 199 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 200 bool losesInfo = false; 201 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 202 &losesInfo); 203 } 204 return value.convertToDouble(); 205 } 206 207 /// Verify construction invariants. 208 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, 209 Type type) { 210 if (!type.isa<FloatType>()) 211 return emitOptionalError(loc, "expected floating point type"); 212 return success(); 213 } 214 215 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 216 MLIRContext *ctx, 217 Type type, double value) { 218 return verifyFloatTypeInvariants(loc, type); 219 } 220 221 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 222 MLIRContext *ctx, 223 Type type, 224 const APFloat &value) { 225 // Verify that the type is correct. 226 if (failed(verifyFloatTypeInvariants(loc, type))) 227 return failure(); 228 229 // Verify that the type semantics match that of the value. 230 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 231 return emitOptionalError( 232 loc, "FloatAttr type doesn't match the type implied by its value"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SymbolRefAttr 239 //===----------------------------------------------------------------------===// 240 241 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 242 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 243 .cast<FlatSymbolRefAttr>(); 244 } 245 246 SymbolRefAttr SymbolRefAttr::get(StringRef value, 247 ArrayRef<FlatSymbolRefAttr> nestedReferences, 248 MLIRContext *ctx) { 249 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 250 } 251 252 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 253 254 StringRef SymbolRefAttr::getLeafReference() const { 255 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 256 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 257 } 258 259 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 260 return getImpl()->getNestedRefs(); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // IntegerAttr 265 //===----------------------------------------------------------------------===// 266 267 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 268 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 269 } 270 271 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 272 // This uses 64 bit APInts by default for index type. 273 if (type.isIndex()) 274 return get(type, APInt(64, value)); 275 276 auto intType = type.cast<IntegerType>(); 277 return get(type, APInt(intType.getWidth(), value)); 278 } 279 280 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 281 282 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } 283 284 //===----------------------------------------------------------------------===// 285 // IntegerSetAttr 286 //===----------------------------------------------------------------------===// 287 288 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 289 return Base::get(value.getConstraint(0).getContext(), 290 StandardAttributes::IntegerSet, value); 291 } 292 293 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 294 295 //===----------------------------------------------------------------------===// 296 // OpaqueAttr 297 //===----------------------------------------------------------------------===// 298 299 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 300 MLIRContext *context) { 301 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 302 type); 303 } 304 305 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 306 Type type, Location location) { 307 return Base::getChecked(location, type.getContext(), 308 StandardAttributes::Opaque, dialect, attrData, type); 309 } 310 311 /// Returns the dialect namespace of the opaque attribute. 312 Identifier OpaqueAttr::getDialectNamespace() const { 313 return getImpl()->dialectNamespace; 314 } 315 316 /// Returns the raw attribute data of the opaque attribute. 317 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 318 319 /// Verify the construction of an opaque attribute. 320 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, 321 MLIRContext *context, 322 Identifier dialect, 323 StringRef attrData, 324 Type type) { 325 if (!Dialect::isValidNamespace(dialect.strref())) 326 return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); 327 return success(); 328 } 329 330 //===----------------------------------------------------------------------===// 331 // StringAttr 332 //===----------------------------------------------------------------------===// 333 334 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 335 return get(bytes, NoneType::get(context)); 336 } 337 338 /// Get an instance of a StringAttr with the given string and Type. 339 StringAttr StringAttr::get(StringRef bytes, Type type) { 340 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 341 } 342 343 StringRef StringAttr::getValue() const { return getImpl()->value; } 344 345 //===----------------------------------------------------------------------===// 346 // TypeAttr 347 //===----------------------------------------------------------------------===// 348 349 TypeAttr TypeAttr::get(Type value) { 350 return Base::get(value.getContext(), StandardAttributes::Type, value); 351 } 352 353 Type TypeAttr::getValue() const { return getImpl()->value; } 354 355 //===----------------------------------------------------------------------===// 356 // ElementsAttr 357 //===----------------------------------------------------------------------===// 358 359 ShapedType ElementsAttr::getType() const { 360 return Attribute::getType().cast<ShapedType>(); 361 } 362 363 /// Returns the number of elements held by this attribute. 364 int64_t ElementsAttr::getNumElements() const { 365 return getType().getNumElements(); 366 } 367 368 /// Return the value at the given index. If index does not refer to a valid 369 /// element, then a null attribute is returned. 370 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 371 switch (getKind()) { 372 case StandardAttributes::DenseElements: 373 return cast<DenseElementsAttr>().getValue(index); 374 case StandardAttributes::OpaqueElements: 375 return cast<OpaqueElementsAttr>().getValue(index); 376 case StandardAttributes::SparseElements: 377 return cast<SparseElementsAttr>().getValue(index); 378 default: 379 llvm_unreachable("unknown ElementsAttr kind"); 380 } 381 } 382 383 /// Return if the given 'index' refers to a valid element in this attribute. 384 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 385 auto type = getType(); 386 387 // Verify that the rank of the indices matches the held type. 388 auto rank = type.getRank(); 389 if (rank != static_cast<int64_t>(index.size())) 390 return false; 391 392 // Verify that all of the indices are within the shape dimensions. 393 auto shape = type.getShape(); 394 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 395 return static_cast<int64_t>(index[i]) < shape[i]; 396 }); 397 } 398 399 ElementsAttr 400 ElementsAttr::mapValues(Type newElementType, 401 function_ref<APInt(const APInt &)> mapping) const { 402 switch (getKind()) { 403 case StandardAttributes::DenseElements: 404 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 405 default: 406 llvm_unreachable("unsupported ElementsAttr subtype"); 407 } 408 } 409 410 ElementsAttr 411 ElementsAttr::mapValues(Type newElementType, 412 function_ref<APInt(const APFloat &)> mapping) const { 413 switch (getKind()) { 414 case StandardAttributes::DenseElements: 415 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 416 default: 417 llvm_unreachable("unsupported ElementsAttr subtype"); 418 } 419 } 420 421 /// Returns the 1 dimensional flattened row-major index from the given 422 /// multi-dimensional index. 423 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 424 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 425 auto type = getType(); 426 427 // Reduce the provided multidimensional index into a flattended 1D row-major 428 // index. 429 auto rank = type.getRank(); 430 auto shape = type.getShape(); 431 uint64_t valueIndex = 0; 432 uint64_t dimMultiplier = 1; 433 for (int i = rank - 1; i >= 0; --i) { 434 valueIndex += index[i] * dimMultiplier; 435 dimMultiplier *= shape[i]; 436 } 437 return valueIndex; 438 } 439 440 //===----------------------------------------------------------------------===// 441 // DenseElementAttr Utilities 442 //===----------------------------------------------------------------------===// 443 444 static size_t getDenseElementBitwidth(Type eltType) { 445 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 446 // with double semantics. 447 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 448 } 449 450 /// Get the bitwidth of a dense element type within the buffer. 451 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 452 static size_t getDenseElementStorageWidth(size_t origWidth) { 453 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 454 } 455 456 /// Set a bit to a specific value. 457 static void setBit(char *rawData, size_t bitPos, bool value) { 458 if (value) 459 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 460 else 461 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 462 } 463 464 /// Return the value of the specified bit. 465 static bool getBit(const char *rawData, size_t bitPos) { 466 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 467 } 468 469 /// Writes value to the bit position `bitPos` in array `rawData`. 470 static void writeBits(char *rawData, size_t bitPos, APInt value) { 471 size_t bitWidth = value.getBitWidth(); 472 473 // If the bitwidth is 1 we just toggle the specific bit. 474 if (bitWidth == 1) 475 return setBit(rawData, bitPos, value.isOneValue()); 476 477 // Otherwise, the bit position is guaranteed to be byte aligned. 478 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 479 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 480 llvm::divideCeil(bitWidth, CHAR_BIT), 481 rawData + (bitPos / CHAR_BIT)); 482 } 483 484 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 485 /// `rawData`. 486 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 487 // Handle a boolean bit position. 488 if (bitWidth == 1) 489 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 490 491 // Otherwise, the bit position must be 8-bit aligned. 492 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 493 APInt result(bitWidth, 0); 494 std::copy_n( 495 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 496 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 497 return result; 498 } 499 500 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 501 /// same element count as 'type'. 502 template <typename Values> 503 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 504 return (values.size() == 1) || 505 (type.getNumElements() == static_cast<int64_t>(values.size())); 506 } 507 508 //===----------------------------------------------------------------------===// 509 // DenseElementAttr Iterators 510 //===----------------------------------------------------------------------===// 511 512 /// Constructs a new iterator. 513 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 514 DenseElementsAttr attr, size_t index) 515 : indexed_accessor_iterator<AttributeElementIterator, const void *, 516 Attribute, Attribute, Attribute>( 517 attr.getAsOpaquePointer(), index) {} 518 519 /// Accesses the Attribute value at this iterator position. 520 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 521 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 522 Type eltTy = owner.getType().getElementType(); 523 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 524 if (intEltTy.getWidth() == 1) 525 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 526 owner.getContext()); 527 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 528 } 529 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 530 IntElementIterator intIt(owner, index); 531 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 532 return FloatAttr::get(eltTy, *floatIt); 533 } 534 llvm_unreachable("unexpected element type"); 535 } 536 537 /// Constructs a new iterator. 538 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 539 DenseElementsAttr attr, size_t dataIndex) 540 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 541 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 542 543 /// Accesses the bool value at this iterator position. 544 bool DenseElementsAttr::BoolElementIterator::operator*() const { 545 return getBit(getData(), getDataIndex()); 546 } 547 548 /// Constructs a new iterator. 549 DenseElementsAttr::IntElementIterator::IntElementIterator( 550 DenseElementsAttr attr, size_t dataIndex) 551 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 552 attr.getRawData().data(), attr.isSplat(), dataIndex), 553 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 554 555 /// Accesses the raw APInt value at this iterator position. 556 APInt DenseElementsAttr::IntElementIterator::operator*() const { 557 return readBits(getData(), 558 getDataIndex() * getDenseElementStorageWidth(bitWidth), 559 bitWidth); 560 } 561 562 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 563 const llvm::fltSemantics &smt, IntElementIterator it) 564 : llvm::mapped_iterator<IntElementIterator, 565 std::function<APFloat(const APInt &)>>( 566 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 567 568 //===----------------------------------------------------------------------===// 569 // DenseElementsAttr 570 //===----------------------------------------------------------------------===// 571 572 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 573 ArrayRef<Attribute> values) { 574 assert(type.getElementType().isIntOrFloat() && 575 "expected int or float element type"); 576 assert(hasSameElementsOrSplat(type, values)); 577 578 auto eltType = type.getElementType(); 579 size_t bitWidth = getDenseElementBitwidth(eltType); 580 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 581 582 // Compress the attribute values into a character buffer. 583 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 584 values.size()); 585 APInt intVal; 586 for (unsigned i = 0, e = values.size(); i < e; ++i) { 587 assert(eltType == values[i].getType() && 588 "expected attribute value to have element type"); 589 590 switch (eltType.getKind()) { 591 case StandardTypes::BF16: 592 case StandardTypes::F16: 593 case StandardTypes::F32: 594 case StandardTypes::F64: 595 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 596 break; 597 case StandardTypes::Integer: 598 intVal = values[i].isa<BoolAttr>() 599 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 600 : values[i].cast<IntegerAttr>().getValue(); 601 break; 602 default: 603 llvm_unreachable("unexpected element type"); 604 } 605 assert(intVal.getBitWidth() == bitWidth && 606 "expected value to have same bitwidth as element type"); 607 writeBits(data.data(), i * storageBitWidth, intVal); 608 } 609 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 610 } 611 612 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 613 ArrayRef<bool> values) { 614 assert(hasSameElementsOrSplat(type, values)); 615 assert(type.getElementType().isInteger(1)); 616 617 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 618 for (int i = 0, e = values.size(); i != e; ++i) 619 setBit(buff.data(), i, values[i]); 620 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 621 } 622 623 /// Constructs a dense integer elements attribute from an array of APInt 624 /// values. Each APInt value is expected to have the same bitwidth as the 625 /// element type of 'type'. 626 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 627 ArrayRef<APInt> values) { 628 assert(type.getElementType().isa<IntegerType>()); 629 return getRaw(type, values); 630 } 631 632 // Constructs a dense float elements attribute from an array of APFloat 633 // values. Each APFloat value is expected to have the same bitwidth as the 634 // element type of 'type'. 635 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 636 ArrayRef<APFloat> values) { 637 assert(type.getElementType().isa<FloatType>()); 638 639 // Convert the APFloat values to APInt and create a dense elements attribute. 640 std::vector<APInt> intValues(values.size()); 641 for (unsigned i = 0, e = values.size(); i != e; ++i) 642 intValues[i] = values[i].bitcastToAPInt(); 643 return getRaw(type, intValues); 644 } 645 646 // Constructs a dense elements attribute from an array of raw APInt values. 647 // Each APInt value is expected to have the same bitwidth as the element type 648 // of 'type'. 649 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 650 ArrayRef<APInt> values) { 651 assert(hasSameElementsOrSplat(type, values)); 652 653 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 654 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 655 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 656 values.size()); 657 for (unsigned i = 0, e = values.size(); i != e; ++i) { 658 assert(values[i].getBitWidth() == bitWidth); 659 writeBits(elementData.data(), i * storageBitWidth, values[i]); 660 } 661 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 662 } 663 664 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 665 ArrayRef<char> data, bool isSplat) { 666 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 667 "type must be ranked tensor or vector"); 668 assert(type.hasStaticShape() && "type must have static shape"); 669 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 670 data, isSplat); 671 } 672 673 /// Check the information for a c++ data type, check if this type is valid for 674 /// the current attribute. This method is used to verify specific type 675 /// invariants that the templatized 'getValues' method cannot. 676 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, 677 bool isInt) { 678 // Make sure that the data element size is the same as the type element width. 679 if (getDenseElementBitwidth(type.getElementType()) != 680 static_cast<size_t>(dataEltSize * CHAR_BIT)) 681 return false; 682 683 // Check that the element type is valid. 684 return isInt ? type.getElementType().isa<IntegerType>() 685 : type.getElementType().isa<FloatType>(); 686 } 687 688 /// Overload of the 'getRaw' method that asserts that the given type is of 689 /// integer type. This method is used to verify type invariants that the 690 /// templatized 'get' method cannot. 691 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 692 ArrayRef<char> data, 693 int64_t dataEltSize, 694 bool isInt) { 695 assert(::isValidIntOrFloat(type, dataEltSize, isInt)); 696 697 int64_t numElements = data.size() / dataEltSize; 698 assert(numElements == 1 || numElements == type.getNumElements()); 699 return getRaw(type, data, /*isSplat=*/numElements == 1); 700 } 701 702 /// A method used to verify specific type invariants that the templatized 'get' 703 /// method cannot. 704 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, 705 bool isInt) const { 706 return ::isValidIntOrFloat(getType(), dataEltSize, isInt); 707 } 708 709 /// Return the raw storage data held by this attribute. 710 ArrayRef<char> DenseElementsAttr::getRawData() const { 711 return static_cast<ImplType *>(impl)->data; 712 } 713 714 /// Returns if this attribute corresponds to a splat, i.e. if all element 715 /// values are the same. 716 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 717 718 /// Return the held element values as a range of Attributes. 719 auto DenseElementsAttr::getAttributeValues() const 720 -> llvm::iterator_range<AttributeElementIterator> { 721 return {attr_value_begin(), attr_value_end()}; 722 } 723 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 724 return AttributeElementIterator(*this, 0); 725 } 726 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 727 return AttributeElementIterator(*this, getNumElements()); 728 } 729 730 /// Return the held element values as a range of bool. The element type of 731 /// this attribute must be of integer type of bitwidth 1. 732 auto DenseElementsAttr::getBoolValues() const 733 -> llvm::iterator_range<BoolElementIterator> { 734 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 735 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 736 (void)eltType; 737 return {BoolElementIterator(*this, 0), 738 BoolElementIterator(*this, getNumElements())}; 739 } 740 741 /// Return the held element values as a range of APInts. The element type of 742 /// this attribute must be of integer type. 743 auto DenseElementsAttr::getIntValues() const 744 -> llvm::iterator_range<IntElementIterator> { 745 assert(getType().getElementType().isa<IntegerType>() && 746 "expected integer type"); 747 return {raw_int_begin(), raw_int_end()}; 748 } 749 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 750 assert(getType().getElementType().isa<IntegerType>() && 751 "expected integer type"); 752 return raw_int_begin(); 753 } 754 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 755 assert(getType().getElementType().isa<IntegerType>() && 756 "expected integer type"); 757 return raw_int_end(); 758 } 759 760 /// Return the held element values as a range of APFloat. The element type of 761 /// this attribute must be of float type. 762 auto DenseElementsAttr::getFloatValues() const 763 -> llvm::iterator_range<FloatElementIterator> { 764 auto elementType = getType().getElementType().cast<FloatType>(); 765 assert(elementType.isa<FloatType>() && "expected float type"); 766 const auto &elementSemantics = elementType.getFloatSemantics(); 767 return {FloatElementIterator(elementSemantics, raw_int_begin()), 768 FloatElementIterator(elementSemantics, raw_int_end())}; 769 } 770 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 771 return getFloatValues().begin(); 772 } 773 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 774 return getFloatValues().end(); 775 } 776 777 /// Return a new DenseElementsAttr that has the same data as the current 778 /// attribute, but has been reshaped to 'newType'. The new type must have the 779 /// same total number of elements as well as element type. 780 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 781 ShapedType curType = getType(); 782 if (curType == newType) 783 return *this; 784 785 (void)curType; 786 assert(newType.getElementType() == curType.getElementType() && 787 "expected the same element type"); 788 assert(newType.getNumElements() == curType.getNumElements() && 789 "expected the same number of elements"); 790 return getRaw(newType, getRawData(), isSplat()); 791 } 792 793 DenseElementsAttr 794 DenseElementsAttr::mapValues(Type newElementType, 795 function_ref<APInt(const APInt &)> mapping) const { 796 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 797 } 798 799 DenseElementsAttr DenseElementsAttr::mapValues( 800 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 801 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 802 } 803 804 //===----------------------------------------------------------------------===// 805 // DenseFPElementsAttr 806 //===----------------------------------------------------------------------===// 807 808 template <typename Fn, typename Attr> 809 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 810 Type newElementType, 811 llvm::SmallVectorImpl<char> &data) { 812 size_t bitWidth = getDenseElementBitwidth(newElementType); 813 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 814 815 ShapedType newArrayType; 816 if (inType.isa<RankedTensorType>()) 817 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 818 else if (inType.isa<UnrankedTensorType>()) 819 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 820 else if (inType.isa<VectorType>()) 821 newArrayType = VectorType::get(inType.getShape(), newElementType); 822 else 823 assert(newArrayType && "Unhandled tensor type"); 824 825 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 826 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 827 828 // Functor used to process a single element value of the attribute. 829 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 830 auto newInt = mapping(value); 831 assert(newInt.getBitWidth() == bitWidth); 832 writeBits(data.data(), index * storageBitWidth, newInt); 833 }; 834 835 // Check for the splat case. 836 if (attr.isSplat()) { 837 processElt(*attr.begin(), /*index=*/0); 838 return newArrayType; 839 } 840 841 // Otherwise, process all of the element values. 842 uint64_t elementIdx = 0; 843 for (auto value : attr) 844 processElt(value, elementIdx++); 845 return newArrayType; 846 } 847 848 DenseElementsAttr DenseFPElementsAttr::mapValues( 849 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 850 llvm::SmallVector<char, 8> elementData; 851 auto newArrayType = 852 mappingHelper(mapping, *this, getType(), newElementType, elementData); 853 854 return getRaw(newArrayType, elementData, isSplat()); 855 } 856 857 /// Method for supporting type inquiry through isa, cast and dyn_cast. 858 bool DenseFPElementsAttr::classof(Attribute attr) { 859 return attr.isa<DenseElementsAttr>() && 860 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 861 } 862 863 //===----------------------------------------------------------------------===// 864 // DenseIntElementsAttr 865 //===----------------------------------------------------------------------===// 866 867 DenseElementsAttr DenseIntElementsAttr::mapValues( 868 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 869 llvm::SmallVector<char, 8> elementData; 870 auto newArrayType = 871 mappingHelper(mapping, *this, getType(), newElementType, elementData); 872 873 return getRaw(newArrayType, elementData, isSplat()); 874 } 875 876 /// Method for supporting type inquiry through isa, cast and dyn_cast. 877 bool DenseIntElementsAttr::classof(Attribute attr) { 878 return attr.isa<DenseElementsAttr>() && 879 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 880 } 881 882 //===----------------------------------------------------------------------===// 883 // OpaqueElementsAttr 884 //===----------------------------------------------------------------------===// 885 886 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 887 StringRef bytes) { 888 assert(TensorType::isValidElementType(type.getElementType()) && 889 "Input element type should be a valid tensor element type"); 890 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 891 dialect, bytes); 892 } 893 894 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 895 896 /// Return the value at the given index. If index does not refer to a valid 897 /// element, then a null attribute is returned. 898 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 899 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 900 if (Dialect *dialect = getDialect()) 901 return dialect->extractElementHook(*this, index); 902 return Attribute(); 903 } 904 905 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 906 907 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 908 if (auto *d = getDialect()) 909 return d->decodeHook(*this, result); 910 return true; 911 } 912 913 //===----------------------------------------------------------------------===// 914 // SparseElementsAttr 915 //===----------------------------------------------------------------------===// 916 917 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 918 DenseElementsAttr indices, 919 DenseElementsAttr values) { 920 assert(indices.getType().getElementType().isInteger(64) && 921 "expected sparse indices to be 64-bit integer values"); 922 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 923 "type must be ranked tensor or vector"); 924 assert(type.hasStaticShape() && "type must have static shape"); 925 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 926 indices.cast<DenseIntElementsAttr>(), values); 927 } 928 929 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 930 return getImpl()->indices; 931 } 932 933 DenseElementsAttr SparseElementsAttr::getValues() const { 934 return getImpl()->values; 935 } 936 937 /// Return the value of the element at the given index. 938 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 939 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 940 auto type = getType(); 941 942 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 943 // as a 1-D index array. 944 auto sparseIndices = getIndices(); 945 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 946 947 // Check to see if the indices are a splat. 948 if (sparseIndices.isSplat()) { 949 // If the index is also not a splat of the index value, we know that the 950 // value is zero. 951 auto splatIndex = *sparseIndexValues.begin(); 952 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 953 return getZeroAttr(); 954 955 // If the indices are a splat, we also expect the values to be a splat. 956 assert(getValues().isSplat() && "expected splat values"); 957 return getValues().getSplatValue(); 958 } 959 960 // Build a mapping between known indices and the offset of the stored element. 961 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 962 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 963 size_t rank = type.getRank(); 964 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 965 mappedIndices.try_emplace( 966 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 967 968 // Look for the provided index key within the mapped indices. If the provided 969 // index is not found, then return a zero attribute. 970 auto it = mappedIndices.find(index); 971 if (it == mappedIndices.end()) 972 return getZeroAttr(); 973 974 // Otherwise, return the held sparse value element. 975 return getValues().getValue(it->second); 976 } 977 978 /// Get a zero APFloat for the given sparse attribute. 979 APFloat SparseElementsAttr::getZeroAPFloat() const { 980 auto eltType = getType().getElementType().cast<FloatType>(); 981 return APFloat(eltType.getFloatSemantics()); 982 } 983 984 /// Get a zero APInt for the given sparse attribute. 985 APInt SparseElementsAttr::getZeroAPInt() const { 986 auto eltType = getType().getElementType().cast<IntegerType>(); 987 return APInt::getNullValue(eltType.getWidth()); 988 } 989 990 /// Get a zero attribute for the given attribute type. 991 Attribute SparseElementsAttr::getZeroAttr() const { 992 auto eltType = getType().getElementType(); 993 994 // Handle floating point elements. 995 if (eltType.isa<FloatType>()) 996 return FloatAttr::get(eltType, 0); 997 998 // Otherwise, this is an integer. 999 auto intEltTy = eltType.cast<IntegerType>(); 1000 if (intEltTy.getWidth() == 1) 1001 return BoolAttr::get(false, eltType.getContext()); 1002 return IntegerAttr::get(eltType, 0); 1003 } 1004 1005 /// Flatten, and return, all of the sparse indices in this attribute in 1006 /// row-major order. 1007 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1008 std::vector<ptrdiff_t> flatSparseIndices; 1009 1010 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1011 // as a 1-D index array. 1012 auto sparseIndices = getIndices(); 1013 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1014 if (sparseIndices.isSplat()) { 1015 SmallVector<uint64_t, 8> indices(getType().getRank(), 1016 *sparseIndexValues.begin()); 1017 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1018 return flatSparseIndices; 1019 } 1020 1021 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1022 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1023 size_t rank = getType().getRank(); 1024 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1025 flatSparseIndices.push_back(getFlattenedIndex( 1026 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1027 return flatSparseIndices; 1028 } 1029 1030 //===----------------------------------------------------------------------===// 1031 // NamedAttributeList 1032 //===----------------------------------------------------------------------===// 1033 1034 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1035 setAttrs(attributes); 1036 } 1037 1038 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1039 return attrs ? attrs.getValue() : llvm::None; 1040 } 1041 1042 /// Replace the held attributes with ones provided in 'newAttrs'. 1043 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1044 // Don't create an attribute list if there are no attributes. 1045 if (attributes.empty()) 1046 attrs = nullptr; 1047 else 1048 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1049 } 1050 1051 /// Return the specified attribute if present, null otherwise. 1052 Attribute NamedAttributeList::get(StringRef name) const { 1053 return attrs ? attrs.get(name) : nullptr; 1054 } 1055 1056 /// Return the specified attribute if present, null otherwise. 1057 Attribute NamedAttributeList::get(Identifier name) const { 1058 return attrs ? attrs.get(name) : nullptr; 1059 } 1060 1061 /// If the an attribute exists with the specified name, change it to the new 1062 /// value. Otherwise, add a new attribute with the specified name/value. 1063 void NamedAttributeList::set(Identifier name, Attribute value) { 1064 assert(value && "attributes may never be null"); 1065 1066 // If we already have this attribute, replace it. 1067 auto origAttrs = getAttrs(); 1068 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1069 for (auto &elt : newAttrs) 1070 if (elt.first == name) { 1071 elt.second = value; 1072 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1073 return; 1074 } 1075 1076 // Otherwise, add it. 1077 newAttrs.push_back({name, value}); 1078 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1079 } 1080 1081 /// Remove the attribute with the specified name if it exists. The return 1082 /// value indicates whether the attribute was present or not. 1083 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1084 auto origAttrs = getAttrs(); 1085 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1086 if (origAttrs[i].first == name) { 1087 // Handle the simple case of removing the only attribute in the list. 1088 if (e == 1) { 1089 attrs = nullptr; 1090 return RemoveResult::Removed; 1091 } 1092 1093 SmallVector<NamedAttribute, 8> newAttrs; 1094 newAttrs.reserve(origAttrs.size() - 1); 1095 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1096 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1097 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1098 return RemoveResult::Removed; 1099 } 1100 } 1101 return RemoveResult::NotFound; 1102 } 1103