1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "AttributeDetail.h" 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Diagnostics.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Function.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 // AttributeStorage 25 //===----------------------------------------------------------------------===// 26 27 AttributeStorage::AttributeStorage(Type type) 28 : type(type.getAsOpaquePointer()) {} 29 AttributeStorage::AttributeStorage() : type(nullptr) {} 30 31 Type AttributeStorage::getType() const { 32 return Type::getFromOpaquePointer(type); 33 } 34 void AttributeStorage::setType(Type newType) { 35 type = newType.getAsOpaquePointer(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Attribute 40 //===----------------------------------------------------------------------===// 41 42 /// Return the type of this attribute. 43 Type Attribute::getType() const { return impl->getType(); } 44 45 /// Return the context this attribute belongs to. 46 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 47 48 /// Get the dialect this attribute is registered to. 49 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 50 51 //===----------------------------------------------------------------------===// 52 // AffineMapAttr 53 //===----------------------------------------------------------------------===// 54 55 AffineMapAttr AffineMapAttr::get(AffineMap value) { 56 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 57 } 58 59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 60 61 //===----------------------------------------------------------------------===// 62 // ArrayAttr 63 //===----------------------------------------------------------------------===// 64 65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 66 return Base::get(context, StandardAttributes::Array, value); 67 } 68 69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 70 71 //===----------------------------------------------------------------------===// 72 // BoolAttr 73 //===----------------------------------------------------------------------===// 74 75 bool BoolAttr::getValue() const { return getImpl()->value; } 76 77 //===----------------------------------------------------------------------===// 78 // DictionaryAttr 79 //===----------------------------------------------------------------------===// 80 81 /// Perform a three-way comparison between the names of the specified 82 /// NamedAttributes. 83 static int compareNamedAttributes(const NamedAttribute *lhs, 84 const NamedAttribute *rhs) { 85 return lhs->first.strref().compare(rhs->first.strref()); 86 } 87 88 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 89 MLIRContext *context) { 90 assert(llvm::all_of(value, 91 [](const NamedAttribute &attr) { return attr.second; }) && 92 "value cannot have null entries"); 93 94 // We need to sort the element list to canonicalize it, but we also don't want 95 // to do a ton of work in the super common case where the element list is 96 // already sorted. 97 SmallVector<NamedAttribute, 8> storage; 98 switch (value.size()) { 99 case 0: 100 break; 101 case 1: 102 // A single element is already sorted. 103 break; 104 case 2: 105 assert(value[0].first != value[1].first && 106 "DictionaryAttr element names must be unique"); 107 108 // Don't invoke a general sort for two element case. 109 if (value[0].first.strref() > value[1].first.strref()) { 110 storage.push_back(value[1]); 111 storage.push_back(value[0]); 112 value = storage; 113 } 114 break; 115 default: 116 // Check to see they are sorted already. 117 bool isSorted = true; 118 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 119 if (value[i].first.strref() > value[i + 1].first.strref()) { 120 isSorted = false; 121 break; 122 } 123 } 124 // If not, do a general sort. 125 if (!isSorted) { 126 storage.append(value.begin(), value.end()); 127 llvm::array_pod_sort(storage.begin(), storage.end(), 128 compareNamedAttributes); 129 value = storage; 130 } 131 132 // Ensure that the attribute elements are unique. 133 assert(std::adjacent_find(value.begin(), value.end(), 134 [](NamedAttribute l, NamedAttribute r) { 135 return l.first == r.first; 136 }) == value.end() && 137 "DictionaryAttr element names must be unique"); 138 } 139 140 return Base::get(context, StandardAttributes::Dictionary, value); 141 } 142 143 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 144 return getImpl()->getElements(); 145 } 146 147 /// Return the specified attribute if present, null otherwise. 148 Attribute DictionaryAttr::get(StringRef name) const { 149 ArrayRef<NamedAttribute> values = getValue(); 150 auto compare = [](NamedAttribute attr, StringRef name) { 151 return attr.first.strref() < name; 152 }; 153 auto it = llvm::lower_bound(values, name, compare); 154 return it != values.end() && it->first.is(name) ? it->second : Attribute(); 155 } 156 Attribute DictionaryAttr::get(Identifier name) const { 157 for (auto elt : getValue()) 158 if (elt.first == name) 159 return elt.second; 160 return nullptr; 161 } 162 163 DictionaryAttr::iterator DictionaryAttr::begin() const { 164 return getValue().begin(); 165 } 166 DictionaryAttr::iterator DictionaryAttr::end() const { 167 return getValue().end(); 168 } 169 size_t DictionaryAttr::size() const { return getValue().size(); } 170 171 //===----------------------------------------------------------------------===// 172 // FloatAttr 173 //===----------------------------------------------------------------------===// 174 175 FloatAttr FloatAttr::get(Type type, double value) { 176 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 177 } 178 179 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 180 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 181 type, value); 182 } 183 184 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 185 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 186 } 187 188 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 189 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 190 type, value); 191 } 192 193 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 194 195 double FloatAttr::getValueAsDouble() const { 196 return getValueAsDouble(getValue()); 197 } 198 double FloatAttr::getValueAsDouble(APFloat value) { 199 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 200 bool losesInfo = false; 201 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 202 &losesInfo); 203 } 204 return value.convertToDouble(); 205 } 206 207 /// Verify construction invariants. 208 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, 209 Type type) { 210 if (!type.isa<FloatType>()) 211 return emitOptionalError(loc, "expected floating point type"); 212 return success(); 213 } 214 215 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 216 MLIRContext *ctx, 217 Type type, double value) { 218 return verifyFloatTypeInvariants(loc, type); 219 } 220 221 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, 222 MLIRContext *ctx, 223 Type type, 224 const APFloat &value) { 225 // Verify that the type is correct. 226 if (failed(verifyFloatTypeInvariants(loc, type))) 227 return failure(); 228 229 // Verify that the type semantics match that of the value. 230 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 231 return emitOptionalError( 232 loc, "FloatAttr type doesn't match the type implied by its value"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SymbolRefAttr 239 //===----------------------------------------------------------------------===// 240 241 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 242 return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) 243 .cast<FlatSymbolRefAttr>(); 244 } 245 246 SymbolRefAttr SymbolRefAttr::get(StringRef value, 247 ArrayRef<FlatSymbolRefAttr> nestedReferences, 248 MLIRContext *ctx) { 249 return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); 250 } 251 252 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } 253 254 StringRef SymbolRefAttr::getLeafReference() const { 255 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); 256 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); 257 } 258 259 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { 260 return getImpl()->getNestedRefs(); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // IntegerAttr 265 //===----------------------------------------------------------------------===// 266 267 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 268 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 269 } 270 271 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 272 // This uses 64 bit APInts by default for index type. 273 if (type.isIndex()) 274 return get(type, APInt(64, value)); 275 276 auto intType = type.cast<IntegerType>(); 277 return get(type, APInt(intType.getWidth(), value)); 278 } 279 280 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 281 282 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } 283 284 static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc, 285 Type type) { 286 if (type.isa<IntegerType>() || type.isa<IndexType>()) 287 return success(); 288 return emitOptionalError(loc, "expected integer or index type"); 289 } 290 291 LogicalResult verifyConstructionInvariants(Optional<Location> loc, 292 MLIRContext *ctx, Type type, 293 int64_t value) { 294 return verifyIntegerTypeInvariants(loc, type); 295 } 296 297 LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc, 298 MLIRContext *ctx, 299 Type type, 300 const APInt &value) { 301 if (failed(verifyIntegerTypeInvariants(loc, type))) 302 return failure(); 303 if (auto integerType = type.dyn_cast<IntegerType>()) 304 if (integerType.getWidth() != value.getBitWidth()) 305 return emitOptionalError( 306 loc, "integer type bit width (", integerType.getWidth(), 307 ") doesn't match value bit width (", value.getBitWidth(), ")"); 308 return success(); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // IntegerSetAttr 313 //===----------------------------------------------------------------------===// 314 315 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 316 return Base::get(value.getConstraint(0).getContext(), 317 StandardAttributes::IntegerSet, value); 318 } 319 320 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 321 322 //===----------------------------------------------------------------------===// 323 // OpaqueAttr 324 //===----------------------------------------------------------------------===// 325 326 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 327 MLIRContext *context) { 328 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 329 type); 330 } 331 332 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 333 Type type, Location location) { 334 return Base::getChecked(location, type.getContext(), 335 StandardAttributes::Opaque, dialect, attrData, type); 336 } 337 338 /// Returns the dialect namespace of the opaque attribute. 339 Identifier OpaqueAttr::getDialectNamespace() const { 340 return getImpl()->dialectNamespace; 341 } 342 343 /// Returns the raw attribute data of the opaque attribute. 344 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 345 346 /// Verify the construction of an opaque attribute. 347 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, 348 MLIRContext *context, 349 Identifier dialect, 350 StringRef attrData, 351 Type type) { 352 if (!Dialect::isValidNamespace(dialect.strref())) 353 return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); 354 return success(); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // StringAttr 359 //===----------------------------------------------------------------------===// 360 361 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 362 return get(bytes, NoneType::get(context)); 363 } 364 365 /// Get an instance of a StringAttr with the given string and Type. 366 StringAttr StringAttr::get(StringRef bytes, Type type) { 367 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 368 } 369 370 StringRef StringAttr::getValue() const { return getImpl()->value; } 371 372 //===----------------------------------------------------------------------===// 373 // TypeAttr 374 //===----------------------------------------------------------------------===// 375 376 TypeAttr TypeAttr::get(Type value) { 377 return Base::get(value.getContext(), StandardAttributes::Type, value); 378 } 379 380 Type TypeAttr::getValue() const { return getImpl()->value; } 381 382 //===----------------------------------------------------------------------===// 383 // ElementsAttr 384 //===----------------------------------------------------------------------===// 385 386 ShapedType ElementsAttr::getType() const { 387 return Attribute::getType().cast<ShapedType>(); 388 } 389 390 /// Returns the number of elements held by this attribute. 391 int64_t ElementsAttr::getNumElements() const { 392 return getType().getNumElements(); 393 } 394 395 /// Return the value at the given index. If index does not refer to a valid 396 /// element, then a null attribute is returned. 397 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 398 switch (getKind()) { 399 case StandardAttributes::DenseElements: 400 return cast<DenseElementsAttr>().getValue(index); 401 case StandardAttributes::OpaqueElements: 402 return cast<OpaqueElementsAttr>().getValue(index); 403 case StandardAttributes::SparseElements: 404 return cast<SparseElementsAttr>().getValue(index); 405 default: 406 llvm_unreachable("unknown ElementsAttr kind"); 407 } 408 } 409 410 /// Return if the given 'index' refers to a valid element in this attribute. 411 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 412 auto type = getType(); 413 414 // Verify that the rank of the indices matches the held type. 415 auto rank = type.getRank(); 416 if (rank != static_cast<int64_t>(index.size())) 417 return false; 418 419 // Verify that all of the indices are within the shape dimensions. 420 auto shape = type.getShape(); 421 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 422 return static_cast<int64_t>(index[i]) < shape[i]; 423 }); 424 } 425 426 ElementsAttr 427 ElementsAttr::mapValues(Type newElementType, 428 function_ref<APInt(const APInt &)> mapping) const { 429 switch (getKind()) { 430 case StandardAttributes::DenseElements: 431 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 432 default: 433 llvm_unreachable("unsupported ElementsAttr subtype"); 434 } 435 } 436 437 ElementsAttr 438 ElementsAttr::mapValues(Type newElementType, 439 function_ref<APInt(const APFloat &)> mapping) const { 440 switch (getKind()) { 441 case StandardAttributes::DenseElements: 442 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 443 default: 444 llvm_unreachable("unsupported ElementsAttr subtype"); 445 } 446 } 447 448 /// Returns the 1 dimensional flattened row-major index from the given 449 /// multi-dimensional index. 450 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 451 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 452 auto type = getType(); 453 454 // Reduce the provided multidimensional index into a flattended 1D row-major 455 // index. 456 auto rank = type.getRank(); 457 auto shape = type.getShape(); 458 uint64_t valueIndex = 0; 459 uint64_t dimMultiplier = 1; 460 for (int i = rank - 1; i >= 0; --i) { 461 valueIndex += index[i] * dimMultiplier; 462 dimMultiplier *= shape[i]; 463 } 464 return valueIndex; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // DenseElementAttr Utilities 469 //===----------------------------------------------------------------------===// 470 471 static size_t getDenseElementBitwidth(Type eltType) { 472 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 473 // with double semantics. 474 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 475 } 476 477 /// Get the bitwidth of a dense element type within the buffer. 478 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 479 static size_t getDenseElementStorageWidth(size_t origWidth) { 480 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 481 } 482 483 /// Set a bit to a specific value. 484 static void setBit(char *rawData, size_t bitPos, bool value) { 485 if (value) 486 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 487 else 488 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 489 } 490 491 /// Return the value of the specified bit. 492 static bool getBit(const char *rawData, size_t bitPos) { 493 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 494 } 495 496 /// Writes value to the bit position `bitPos` in array `rawData`. 497 static void writeBits(char *rawData, size_t bitPos, APInt value) { 498 size_t bitWidth = value.getBitWidth(); 499 500 // If the bitwidth is 1 we just toggle the specific bit. 501 if (bitWidth == 1) 502 return setBit(rawData, bitPos, value.isOneValue()); 503 504 // Otherwise, the bit position is guaranteed to be byte aligned. 505 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 506 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 507 llvm::divideCeil(bitWidth, CHAR_BIT), 508 rawData + (bitPos / CHAR_BIT)); 509 } 510 511 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 512 /// `rawData`. 513 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 514 // Handle a boolean bit position. 515 if (bitWidth == 1) 516 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 517 518 // Otherwise, the bit position must be 8-bit aligned. 519 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 520 APInt result(bitWidth, 0); 521 std::copy_n( 522 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 523 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 524 return result; 525 } 526 527 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 528 /// same element count as 'type'. 529 template <typename Values> 530 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 531 return (values.size() == 1) || 532 (type.getNumElements() == static_cast<int64_t>(values.size())); 533 } 534 535 //===----------------------------------------------------------------------===// 536 // DenseElementAttr Iterators 537 //===----------------------------------------------------------------------===// 538 539 /// Constructs a new iterator. 540 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 541 DenseElementsAttr attr, size_t index) 542 : indexed_accessor_iterator<AttributeElementIterator, const void *, 543 Attribute, Attribute, Attribute>( 544 attr.getAsOpaquePointer(), index) {} 545 546 /// Accesses the Attribute value at this iterator position. 547 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 548 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); 549 Type eltTy = owner.getType().getElementType(); 550 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 551 if (intEltTy.getWidth() == 1) 552 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 553 owner.getContext()); 554 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 555 } 556 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 557 IntElementIterator intIt(owner, index); 558 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 559 return FloatAttr::get(eltTy, *floatIt); 560 } 561 llvm_unreachable("unexpected element type"); 562 } 563 564 /// Constructs a new iterator. 565 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 566 DenseElementsAttr attr, size_t dataIndex) 567 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 568 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 569 570 /// Accesses the bool value at this iterator position. 571 bool DenseElementsAttr::BoolElementIterator::operator*() const { 572 return getBit(getData(), getDataIndex()); 573 } 574 575 /// Constructs a new iterator. 576 DenseElementsAttr::IntElementIterator::IntElementIterator( 577 DenseElementsAttr attr, size_t dataIndex) 578 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 579 attr.getRawData().data(), attr.isSplat(), dataIndex), 580 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 581 582 /// Accesses the raw APInt value at this iterator position. 583 APInt DenseElementsAttr::IntElementIterator::operator*() const { 584 return readBits(getData(), 585 getDataIndex() * getDenseElementStorageWidth(bitWidth), 586 bitWidth); 587 } 588 589 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 590 const llvm::fltSemantics &smt, IntElementIterator it) 591 : llvm::mapped_iterator<IntElementIterator, 592 std::function<APFloat(const APInt &)>>( 593 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 594 595 //===----------------------------------------------------------------------===// 596 // DenseElementsAttr 597 //===----------------------------------------------------------------------===// 598 599 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 600 ArrayRef<Attribute> values) { 601 assert(type.getElementType().isIntOrFloat() && 602 "expected int or float element type"); 603 assert(hasSameElementsOrSplat(type, values)); 604 605 auto eltType = type.getElementType(); 606 size_t bitWidth = getDenseElementBitwidth(eltType); 607 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 608 609 // Compress the attribute values into a character buffer. 610 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 611 values.size()); 612 APInt intVal; 613 for (unsigned i = 0, e = values.size(); i < e; ++i) { 614 assert(eltType == values[i].getType() && 615 "expected attribute value to have element type"); 616 617 switch (eltType.getKind()) { 618 case StandardTypes::BF16: 619 case StandardTypes::F16: 620 case StandardTypes::F32: 621 case StandardTypes::F64: 622 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 623 break; 624 case StandardTypes::Integer: 625 intVal = values[i].isa<BoolAttr>() 626 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 627 : values[i].cast<IntegerAttr>().getValue(); 628 break; 629 default: 630 llvm_unreachable("unexpected element type"); 631 } 632 assert(intVal.getBitWidth() == bitWidth && 633 "expected value to have same bitwidth as element type"); 634 writeBits(data.data(), i * storageBitWidth, intVal); 635 } 636 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 637 } 638 639 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 640 ArrayRef<bool> values) { 641 assert(hasSameElementsOrSplat(type, values)); 642 assert(type.getElementType().isInteger(1)); 643 644 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 645 for (int i = 0, e = values.size(); i != e; ++i) 646 setBit(buff.data(), i, values[i]); 647 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 648 } 649 650 /// Constructs a dense integer elements attribute from an array of APInt 651 /// values. Each APInt value is expected to have the same bitwidth as the 652 /// element type of 'type'. 653 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 654 ArrayRef<APInt> values) { 655 assert(type.getElementType().isa<IntegerType>()); 656 return getRaw(type, values); 657 } 658 659 // Constructs a dense float elements attribute from an array of APFloat 660 // values. Each APFloat value is expected to have the same bitwidth as the 661 // element type of 'type'. 662 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 663 ArrayRef<APFloat> values) { 664 assert(type.getElementType().isa<FloatType>()); 665 666 // Convert the APFloat values to APInt and create a dense elements attribute. 667 std::vector<APInt> intValues(values.size()); 668 for (unsigned i = 0, e = values.size(); i != e; ++i) 669 intValues[i] = values[i].bitcastToAPInt(); 670 return getRaw(type, intValues); 671 } 672 673 // Constructs a dense elements attribute from an array of raw APInt values. 674 // Each APInt value is expected to have the same bitwidth as the element type 675 // of 'type'. 676 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 677 ArrayRef<APInt> values) { 678 assert(hasSameElementsOrSplat(type, values)); 679 680 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 681 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 682 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 683 values.size()); 684 for (unsigned i = 0, e = values.size(); i != e; ++i) { 685 assert(values[i].getBitWidth() == bitWidth); 686 writeBits(elementData.data(), i * storageBitWidth, values[i]); 687 } 688 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 689 } 690 691 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 692 ArrayRef<char> data, bool isSplat) { 693 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 694 "type must be ranked tensor or vector"); 695 assert(type.hasStaticShape() && "type must have static shape"); 696 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 697 data, isSplat); 698 } 699 700 /// Check the information for a c++ data type, check if this type is valid for 701 /// the current attribute. This method is used to verify specific type 702 /// invariants that the templatized 'getValues' method cannot. 703 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, 704 bool isInt) { 705 // Make sure that the data element size is the same as the type element width. 706 if (getDenseElementBitwidth(type.getElementType()) != 707 static_cast<size_t>(dataEltSize * CHAR_BIT)) 708 return false; 709 710 // Check that the element type is valid. 711 return isInt ? type.getElementType().isa<IntegerType>() 712 : type.getElementType().isa<FloatType>(); 713 } 714 715 /// Overload of the 'getRaw' method that asserts that the given type is of 716 /// integer type. This method is used to verify type invariants that the 717 /// templatized 'get' method cannot. 718 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 719 ArrayRef<char> data, 720 int64_t dataEltSize, 721 bool isInt) { 722 assert(::isValidIntOrFloat(type, dataEltSize, isInt)); 723 724 int64_t numElements = data.size() / dataEltSize; 725 assert(numElements == 1 || numElements == type.getNumElements()); 726 return getRaw(type, data, /*isSplat=*/numElements == 1); 727 } 728 729 /// A method used to verify specific type invariants that the templatized 'get' 730 /// method cannot. 731 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, 732 bool isInt) const { 733 return ::isValidIntOrFloat(getType(), dataEltSize, isInt); 734 } 735 736 /// Return the raw storage data held by this attribute. 737 ArrayRef<char> DenseElementsAttr::getRawData() const { 738 return static_cast<ImplType *>(impl)->data; 739 } 740 741 /// Returns if this attribute corresponds to a splat, i.e. if all element 742 /// values are the same. 743 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 744 745 /// Return the held element values as a range of Attributes. 746 auto DenseElementsAttr::getAttributeValues() const 747 -> llvm::iterator_range<AttributeElementIterator> { 748 return {attr_value_begin(), attr_value_end()}; 749 } 750 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 751 return AttributeElementIterator(*this, 0); 752 } 753 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 754 return AttributeElementIterator(*this, getNumElements()); 755 } 756 757 /// Return the held element values as a range of bool. The element type of 758 /// this attribute must be of integer type of bitwidth 1. 759 auto DenseElementsAttr::getBoolValues() const 760 -> llvm::iterator_range<BoolElementIterator> { 761 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 762 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 763 (void)eltType; 764 return {BoolElementIterator(*this, 0), 765 BoolElementIterator(*this, getNumElements())}; 766 } 767 768 /// Return the held element values as a range of APInts. The element type of 769 /// this attribute must be of integer type. 770 auto DenseElementsAttr::getIntValues() const 771 -> llvm::iterator_range<IntElementIterator> { 772 assert(getType().getElementType().isa<IntegerType>() && 773 "expected integer type"); 774 return {raw_int_begin(), raw_int_end()}; 775 } 776 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 777 assert(getType().getElementType().isa<IntegerType>() && 778 "expected integer type"); 779 return raw_int_begin(); 780 } 781 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 782 assert(getType().getElementType().isa<IntegerType>() && 783 "expected integer type"); 784 return raw_int_end(); 785 } 786 787 /// Return the held element values as a range of APFloat. The element type of 788 /// this attribute must be of float type. 789 auto DenseElementsAttr::getFloatValues() const 790 -> llvm::iterator_range<FloatElementIterator> { 791 auto elementType = getType().getElementType().cast<FloatType>(); 792 assert(elementType.isa<FloatType>() && "expected float type"); 793 const auto &elementSemantics = elementType.getFloatSemantics(); 794 return {FloatElementIterator(elementSemantics, raw_int_begin()), 795 FloatElementIterator(elementSemantics, raw_int_end())}; 796 } 797 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 798 return getFloatValues().begin(); 799 } 800 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 801 return getFloatValues().end(); 802 } 803 804 /// Return a new DenseElementsAttr that has the same data as the current 805 /// attribute, but has been reshaped to 'newType'. The new type must have the 806 /// same total number of elements as well as element type. 807 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 808 ShapedType curType = getType(); 809 if (curType == newType) 810 return *this; 811 812 (void)curType; 813 assert(newType.getElementType() == curType.getElementType() && 814 "expected the same element type"); 815 assert(newType.getNumElements() == curType.getNumElements() && 816 "expected the same number of elements"); 817 return getRaw(newType, getRawData(), isSplat()); 818 } 819 820 DenseElementsAttr 821 DenseElementsAttr::mapValues(Type newElementType, 822 function_ref<APInt(const APInt &)> mapping) const { 823 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 824 } 825 826 DenseElementsAttr DenseElementsAttr::mapValues( 827 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 828 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 829 } 830 831 //===----------------------------------------------------------------------===// 832 // DenseFPElementsAttr 833 //===----------------------------------------------------------------------===// 834 835 template <typename Fn, typename Attr> 836 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 837 Type newElementType, 838 llvm::SmallVectorImpl<char> &data) { 839 size_t bitWidth = getDenseElementBitwidth(newElementType); 840 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 841 842 ShapedType newArrayType; 843 if (inType.isa<RankedTensorType>()) 844 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 845 else if (inType.isa<UnrankedTensorType>()) 846 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 847 else if (inType.isa<VectorType>()) 848 newArrayType = VectorType::get(inType.getShape(), newElementType); 849 else 850 assert(newArrayType && "Unhandled tensor type"); 851 852 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 853 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 854 855 // Functor used to process a single element value of the attribute. 856 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 857 auto newInt = mapping(value); 858 assert(newInt.getBitWidth() == bitWidth); 859 writeBits(data.data(), index * storageBitWidth, newInt); 860 }; 861 862 // Check for the splat case. 863 if (attr.isSplat()) { 864 processElt(*attr.begin(), /*index=*/0); 865 return newArrayType; 866 } 867 868 // Otherwise, process all of the element values. 869 uint64_t elementIdx = 0; 870 for (auto value : attr) 871 processElt(value, elementIdx++); 872 return newArrayType; 873 } 874 875 DenseElementsAttr DenseFPElementsAttr::mapValues( 876 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 877 llvm::SmallVector<char, 8> elementData; 878 auto newArrayType = 879 mappingHelper(mapping, *this, getType(), newElementType, elementData); 880 881 return getRaw(newArrayType, elementData, isSplat()); 882 } 883 884 /// Method for supporting type inquiry through isa, cast and dyn_cast. 885 bool DenseFPElementsAttr::classof(Attribute attr) { 886 return attr.isa<DenseElementsAttr>() && 887 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 888 } 889 890 //===----------------------------------------------------------------------===// 891 // DenseIntElementsAttr 892 //===----------------------------------------------------------------------===// 893 894 DenseElementsAttr DenseIntElementsAttr::mapValues( 895 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 896 llvm::SmallVector<char, 8> elementData; 897 auto newArrayType = 898 mappingHelper(mapping, *this, getType(), newElementType, elementData); 899 900 return getRaw(newArrayType, elementData, isSplat()); 901 } 902 903 /// Method for supporting type inquiry through isa, cast and dyn_cast. 904 bool DenseIntElementsAttr::classof(Attribute attr) { 905 return attr.isa<DenseElementsAttr>() && 906 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 907 } 908 909 //===----------------------------------------------------------------------===// 910 // OpaqueElementsAttr 911 //===----------------------------------------------------------------------===// 912 913 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 914 StringRef bytes) { 915 assert(TensorType::isValidElementType(type.getElementType()) && 916 "Input element type should be a valid tensor element type"); 917 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 918 dialect, bytes); 919 } 920 921 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 922 923 /// Return the value at the given index. If index does not refer to a valid 924 /// element, then a null attribute is returned. 925 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 926 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 927 if (Dialect *dialect = getDialect()) 928 return dialect->extractElementHook(*this, index); 929 return Attribute(); 930 } 931 932 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 933 934 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 935 if (auto *d = getDialect()) 936 return d->decodeHook(*this, result); 937 return true; 938 } 939 940 //===----------------------------------------------------------------------===// 941 // SparseElementsAttr 942 //===----------------------------------------------------------------------===// 943 944 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 945 DenseElementsAttr indices, 946 DenseElementsAttr values) { 947 assert(indices.getType().getElementType().isInteger(64) && 948 "expected sparse indices to be 64-bit integer values"); 949 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 950 "type must be ranked tensor or vector"); 951 assert(type.hasStaticShape() && "type must have static shape"); 952 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 953 indices.cast<DenseIntElementsAttr>(), values); 954 } 955 956 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 957 return getImpl()->indices; 958 } 959 960 DenseElementsAttr SparseElementsAttr::getValues() const { 961 return getImpl()->values; 962 } 963 964 /// Return the value of the element at the given index. 965 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 966 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 967 auto type = getType(); 968 969 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 970 // as a 1-D index array. 971 auto sparseIndices = getIndices(); 972 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 973 974 // Check to see if the indices are a splat. 975 if (sparseIndices.isSplat()) { 976 // If the index is also not a splat of the index value, we know that the 977 // value is zero. 978 auto splatIndex = *sparseIndexValues.begin(); 979 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 980 return getZeroAttr(); 981 982 // If the indices are a splat, we also expect the values to be a splat. 983 assert(getValues().isSplat() && "expected splat values"); 984 return getValues().getSplatValue(); 985 } 986 987 // Build a mapping between known indices and the offset of the stored element. 988 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 989 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 990 size_t rank = type.getRank(); 991 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 992 mappedIndices.try_emplace( 993 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 994 995 // Look for the provided index key within the mapped indices. If the provided 996 // index is not found, then return a zero attribute. 997 auto it = mappedIndices.find(index); 998 if (it == mappedIndices.end()) 999 return getZeroAttr(); 1000 1001 // Otherwise, return the held sparse value element. 1002 return getValues().getValue(it->second); 1003 } 1004 1005 /// Get a zero APFloat for the given sparse attribute. 1006 APFloat SparseElementsAttr::getZeroAPFloat() const { 1007 auto eltType = getType().getElementType().cast<FloatType>(); 1008 return APFloat(eltType.getFloatSemantics()); 1009 } 1010 1011 /// Get a zero APInt for the given sparse attribute. 1012 APInt SparseElementsAttr::getZeroAPInt() const { 1013 auto eltType = getType().getElementType().cast<IntegerType>(); 1014 return APInt::getNullValue(eltType.getWidth()); 1015 } 1016 1017 /// Get a zero attribute for the given attribute type. 1018 Attribute SparseElementsAttr::getZeroAttr() const { 1019 auto eltType = getType().getElementType(); 1020 1021 // Handle floating point elements. 1022 if (eltType.isa<FloatType>()) 1023 return FloatAttr::get(eltType, 0); 1024 1025 // Otherwise, this is an integer. 1026 auto intEltTy = eltType.cast<IntegerType>(); 1027 if (intEltTy.getWidth() == 1) 1028 return BoolAttr::get(false, eltType.getContext()); 1029 return IntegerAttr::get(eltType, 0); 1030 } 1031 1032 /// Flatten, and return, all of the sparse indices in this attribute in 1033 /// row-major order. 1034 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1035 std::vector<ptrdiff_t> flatSparseIndices; 1036 1037 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1038 // as a 1-D index array. 1039 auto sparseIndices = getIndices(); 1040 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1041 if (sparseIndices.isSplat()) { 1042 SmallVector<uint64_t, 8> indices(getType().getRank(), 1043 *sparseIndexValues.begin()); 1044 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1045 return flatSparseIndices; 1046 } 1047 1048 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1049 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1050 size_t rank = getType().getRank(); 1051 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1052 flatSparseIndices.push_back(getFlattenedIndex( 1053 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1054 return flatSparseIndices; 1055 } 1056 1057 //===----------------------------------------------------------------------===// 1058 // NamedAttributeList 1059 //===----------------------------------------------------------------------===// 1060 1061 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1062 setAttrs(attributes); 1063 } 1064 1065 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1066 return attrs ? attrs.getValue() : llvm::None; 1067 } 1068 1069 /// Replace the held attributes with ones provided in 'newAttrs'. 1070 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1071 // Don't create an attribute list if there are no attributes. 1072 if (attributes.empty()) 1073 attrs = nullptr; 1074 else 1075 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1076 } 1077 1078 /// Return the specified attribute if present, null otherwise. 1079 Attribute NamedAttributeList::get(StringRef name) const { 1080 return attrs ? attrs.get(name) : nullptr; 1081 } 1082 1083 /// Return the specified attribute if present, null otherwise. 1084 Attribute NamedAttributeList::get(Identifier name) const { 1085 return attrs ? attrs.get(name) : nullptr; 1086 } 1087 1088 /// If the an attribute exists with the specified name, change it to the new 1089 /// value. Otherwise, add a new attribute with the specified name/value. 1090 void NamedAttributeList::set(Identifier name, Attribute value) { 1091 assert(value && "attributes may never be null"); 1092 1093 // If we already have this attribute, replace it. 1094 auto origAttrs = getAttrs(); 1095 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1096 for (auto &elt : newAttrs) 1097 if (elt.first == name) { 1098 elt.second = value; 1099 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1100 return; 1101 } 1102 1103 // Otherwise, add it. 1104 newAttrs.push_back({name, value}); 1105 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1106 } 1107 1108 /// Remove the attribute with the specified name if it exists. The return 1109 /// value indicates whether the attribute was present or not. 1110 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1111 auto origAttrs = getAttrs(); 1112 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1113 if (origAttrs[i].first == name) { 1114 // Handle the simple case of removing the only attribute in the list. 1115 if (e == 1) { 1116 attrs = nullptr; 1117 return RemoveResult::Removed; 1118 } 1119 1120 SmallVector<NamedAttribute, 8> newAttrs; 1121 newAttrs.reserve(origAttrs.size() - 1); 1122 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1123 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1124 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1125 return RemoveResult::Removed; 1126 } 1127 } 1128 return RemoveResult::NotFound; 1129 } 1130