//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 // AttributeStorage
25 //===----------------------------------------------------------------------===//
26 
27 AttributeStorage::AttributeStorage(Type type)
28     : type(type.getAsOpaquePointer()) {}
29 AttributeStorage::AttributeStorage() : type(nullptr) {}
30 
31 Type AttributeStorage::getType() const {
32   return Type::getFromOpaquePointer(type);
33 }
34 void AttributeStorage::setType(Type newType) {
35   type = newType.getAsOpaquePointer();
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Attribute
40 //===----------------------------------------------------------------------===//
41 
42 /// Return the type of this attribute.
43 Type Attribute::getType() const { return impl->getType(); }
44 
45 /// Return the context this attribute belongs to.
46 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
47 
48 /// Get the dialect this attribute is registered to.
49 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
50 
51 //===----------------------------------------------------------------------===//
52 // AffineMapAttr
53 //===----------------------------------------------------------------------===//
54 
55 AffineMapAttr AffineMapAttr::get(AffineMap value) {
56   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
57 }
58 
59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
60 
61 //===----------------------------------------------------------------------===//
62 // ArrayAttr
63 //===----------------------------------------------------------------------===//
64 
65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
66   return Base::get(context, StandardAttributes::Array, value);
67 }
68 
69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
70 
71 Attribute ArrayAttr::operator[](unsigned idx) const {
72   assert(idx < size() && "index out of bounds");
73   return getValue()[idx];
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // BoolAttr
78 //===----------------------------------------------------------------------===//
79 
80 bool BoolAttr::getValue() const { return getImpl()->value; }
81 
82 //===----------------------------------------------------------------------===//
83 // DictionaryAttr
84 //===----------------------------------------------------------------------===//
85 
86 /// Perform a three-way comparison between the names of the specified
87 /// NamedAttributes.
88 static int compareNamedAttributes(const NamedAttribute *lhs,
89                                   const NamedAttribute *rhs) {
90   return lhs->first.strref().compare(rhs->first.strref());
91 }
92 
93 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
94                                    MLIRContext *context) {
95   assert(llvm::all_of(value,
96                       [](const NamedAttribute &attr) { return attr.second; }) &&
97          "value cannot have null entries");
98 
99   // We need to sort the element list to canonicalize it, but we also don't want
100   // to do a ton of work in the super common case where the element list is
101   // already sorted.
102   SmallVector<NamedAttribute, 8> storage;
103   switch (value.size()) {
104   case 0:
105     break;
106   case 1:
107     // A single element is already sorted.
108     break;
109   case 2:
110     assert(value[0].first != value[1].first &&
111            "DictionaryAttr element names must be unique");
112 
113     // Don't invoke a general sort for two element case.
114     if (value[0].first.strref() > value[1].first.strref()) {
115       storage.push_back(value[1]);
116       storage.push_back(value[0]);
117       value = storage;
118     }
119     break;
120   default:
121     // Check to see they are sorted already.
122     bool isSorted = true;
123     for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
124       if (value[i].first.strref() > value[i + 1].first.strref()) {
125         isSorted = false;
126         break;
127       }
128     }
129     // If not, do a general sort.
130     if (!isSorted) {
131       storage.append(value.begin(), value.end());
132       llvm::array_pod_sort(storage.begin(), storage.end(),
133                            compareNamedAttributes);
134       value = storage;
135     }
136 
137     // Ensure that the attribute elements are unique.
138     assert(std::adjacent_find(value.begin(), value.end(),
139                               [](NamedAttribute l, NamedAttribute r) {
140                                 return l.first == r.first;
141                               }) == value.end() &&
142            "DictionaryAttr element names must be unique");
143   }
144 
145   return Base::get(context, StandardAttributes::Dictionary, value);
146 }
147 
148 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
149   return getImpl()->getElements();
150 }
151 
152 /// Return the specified attribute if present, null otherwise.
153 Attribute DictionaryAttr::get(StringRef name) const {
154   ArrayRef<NamedAttribute> values = getValue();
155   auto compare = [](NamedAttribute attr, StringRef name) {
156     return attr.first.strref() < name;
157   };
158   auto it = llvm::lower_bound(values, name, compare);
159   return it != values.end() && it->first.is(name) ? it->second : Attribute();
160 }
161 Attribute DictionaryAttr::get(Identifier name) const {
162   for (auto elt : getValue())
163     if (elt.first == name)
164       return elt.second;
165   return nullptr;
166 }
167 
168 DictionaryAttr::iterator DictionaryAttr::begin() const {
169   return getValue().begin();
170 }
171 DictionaryAttr::iterator DictionaryAttr::end() const {
172   return getValue().end();
173 }
174 size_t DictionaryAttr::size() const { return getValue().size(); }
175 
176 //===----------------------------------------------------------------------===//
177 // FloatAttr
178 //===----------------------------------------------------------------------===//
179 
180 FloatAttr FloatAttr::get(Type type, double value) {
181   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
182 }
183 
184 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
185   return Base::getChecked(loc, StandardAttributes::Float, type, value);
186 }
187 
188 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
189   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
190 }
191 
192 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
193   return Base::getChecked(loc, StandardAttributes::Float, type, value);
194 }
195 
196 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
197 
198 double FloatAttr::getValueAsDouble() const {
199   return getValueAsDouble(getValue());
200 }
201 double FloatAttr::getValueAsDouble(APFloat value) {
202   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
203     bool losesInfo = false;
204     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
205                   &losesInfo);
206   }
207   return value.convertToDouble();
208 }
209 
210 /// Verify construction invariants.
211 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
212   if (!type.isa<FloatType>())
213     return emitError(loc, "expected floating point type");
214   return success();
215 }
216 
217 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
218                                                       double value) {
219   return verifyFloatTypeInvariants(loc, type);
220 }
221 
222 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
223                                                       const APFloat &value) {
224   // Verify that the type is correct.
225   if (failed(verifyFloatTypeInvariants(loc, type)))
226     return failure();
227 
228   // Verify that the type semantics match that of the value.
229   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
230     return emitError(
231         loc, "FloatAttr type doesn't match the type implied by its value");
232   }
233   return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // SymbolRefAttr
238 //===----------------------------------------------------------------------===//
239 
240 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
241   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
242       .cast<FlatSymbolRefAttr>();
243 }
244 
245 SymbolRefAttr SymbolRefAttr::get(StringRef value,
246                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
247                                  MLIRContext *ctx) {
248   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
249 }
250 
251 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
252 
253 StringRef SymbolRefAttr::getLeafReference() const {
254   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
255   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
256 }
257 
258 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
259   return getImpl()->getNestedRefs();
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // IntegerAttr
264 //===----------------------------------------------------------------------===//
265 
266 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
267   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
268 }
269 
270 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
271   // This uses 64 bit APInts by default for index type.
272   if (type.isIndex())
273     return get(type, APInt(64, value));
274 
275   auto intType = type.cast<IntegerType>();
276   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
277 }
278 
279 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
280 
281 int64_t IntegerAttr::getInt() const {
282   assert((getImpl()->getType().isIndex() ||
283           getImpl()->getType().isSignlessInteger()) &&
284          "must be signless integer");
285   return getValue().getSExtValue();
286 }
287 
288 int64_t IntegerAttr::getSInt() const {
289   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
290   return getValue().getSExtValue();
291 }
292 
293 uint64_t IntegerAttr::getUInt() const {
294   assert(getImpl()->getType().isUnsignedInteger() &&
295          "must be unsigned integer");
296   return getValue().getZExtValue();
297 }
298 
299 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
300   if (type.isa<IntegerType>() || type.isa<IndexType>())
301     return success();
302   return emitError(loc, "expected integer or index type");
303 }
304 
305 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
306                                                         int64_t value) {
307   return verifyIntegerTypeInvariants(loc, type);
308 }
309 
310 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
311                                                         const APInt &value) {
312   if (failed(verifyIntegerTypeInvariants(loc, type)))
313     return failure();
314   if (auto integerType = type.dyn_cast<IntegerType>())
315     if (integerType.getWidth() != value.getBitWidth())
316       return emitError(loc, "integer type bit width (")
317              << integerType.getWidth() << ") doesn't match value bit width ("
318              << value.getBitWidth() << ")";
319   return success();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // IntegerSetAttr
324 //===----------------------------------------------------------------------===//
325 
326 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
327   return Base::get(value.getConstraint(0).getContext(),
328                    StandardAttributes::IntegerSet, value);
329 }
330 
331 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
332 
333 //===----------------------------------------------------------------------===//
334 // OpaqueAttr
335 //===----------------------------------------------------------------------===//
336 
337 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
338                            MLIRContext *context) {
339   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
340                    type);
341 }
342 
343 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
344                                   Type type, Location location) {
345   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
346                           attrData, type);
347 }
348 
349 /// Returns the dialect namespace of the opaque attribute.
350 Identifier OpaqueAttr::getDialectNamespace() const {
351   return getImpl()->dialectNamespace;
352 }
353 
354 /// Returns the raw attribute data of the opaque attribute.
355 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
356 
357 /// Verify the construction of an opaque attribute.
358 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
359                                                        Identifier dialect,
360                                                        StringRef attrData,
361                                                        Type type) {
362   if (!Dialect::isValidNamespace(dialect.strref()))
363     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
364   return success();
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // StringAttr
369 //===----------------------------------------------------------------------===//
370 
371 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
372   return get(bytes, NoneType::get(context));
373 }
374 
375 /// Get an instance of a StringAttr with the given string and Type.
376 StringAttr StringAttr::get(StringRef bytes, Type type) {
377   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
378 }
379 
380 StringRef StringAttr::getValue() const { return getImpl()->value; }
381 
382 //===----------------------------------------------------------------------===//
383 // TypeAttr
384 //===----------------------------------------------------------------------===//
385 
386 TypeAttr TypeAttr::get(Type value) {
387   return Base::get(value.getContext(), StandardAttributes::Type, value);
388 }
389 
390 Type TypeAttr::getValue() const { return getImpl()->value; }
391 
392 //===----------------------------------------------------------------------===//
393 // ElementsAttr
394 //===----------------------------------------------------------------------===//
395 
396 ShapedType ElementsAttr::getType() const {
397   return Attribute::getType().cast<ShapedType>();
398 }
399 
400 /// Returns the number of elements held by this attribute.
401 int64_t ElementsAttr::getNumElements() const {
402   return getType().getNumElements();
403 }
404 
405 /// Return the value at the given index. If index does not refer to a valid
406 /// element, then a null attribute is returned.
407 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
408   switch (getKind()) {
409   case StandardAttributes::DenseElements:
410     return cast<DenseElementsAttr>().getValue(index);
411   case StandardAttributes::OpaqueElements:
412     return cast<OpaqueElementsAttr>().getValue(index);
413   case StandardAttributes::SparseElements:
414     return cast<SparseElementsAttr>().getValue(index);
415   default:
416     llvm_unreachable("unknown ElementsAttr kind");
417   }
418 }
419 
420 /// Return if the given 'index' refers to a valid element in this attribute.
421 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
422   auto type = getType();
423 
424   // Verify that the rank of the indices matches the held type.
425   auto rank = type.getRank();
426   if (rank != static_cast<int64_t>(index.size()))
427     return false;
428 
429   // Verify that all of the indices are within the shape dimensions.
430   auto shape = type.getShape();
431   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
432     return static_cast<int64_t>(index[i]) < shape[i];
433   });
434 }
435 
436 ElementsAttr
437 ElementsAttr::mapValues(Type newElementType,
438                         function_ref<APInt(const APInt &)> mapping) const {
439   switch (getKind()) {
440   case StandardAttributes::DenseElements:
441     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
442   default:
443     llvm_unreachable("unsupported ElementsAttr subtype");
444   }
445 }
446 
447 ElementsAttr
448 ElementsAttr::mapValues(Type newElementType,
449                         function_ref<APInt(const APFloat &)> mapping) const {
450   switch (getKind()) {
451   case StandardAttributes::DenseElements:
452     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
453   default:
454     llvm_unreachable("unsupported ElementsAttr subtype");
455   }
456 }
457 
458 /// Returns the 1 dimensional flattened row-major index from the given
459 /// multi-dimensional index.
460 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
461   assert(isValidIndex(index) && "expected valid multi-dimensional index");
462   auto type = getType();
463 
464   // Reduce the provided multidimensional index into a flattended 1D row-major
465   // index.
466   auto rank = type.getRank();
467   auto shape = type.getShape();
468   uint64_t valueIndex = 0;
469   uint64_t dimMultiplier = 1;
470   for (int i = rank - 1; i >= 0; --i) {
471     valueIndex += index[i] * dimMultiplier;
472     dimMultiplier *= shape[i];
473   }
474   return valueIndex;
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // DenseElementAttr Utilities
479 //===----------------------------------------------------------------------===//
480 
481 static size_t getDenseElementBitwidth(Type eltType) {
482   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
483   // with double semantics.
484   return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
485 }
486 
487 /// Get the bitwidth of a dense element type within the buffer.
488 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
489 static size_t getDenseElementStorageWidth(size_t origWidth) {
490   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
491 }
492 
/// Set (when `value` is true) or clear the bit at position `bitPos` of the
/// bit-packed buffer `rawData`, leaving all other bits untouched.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (!value) {
    byte &= ~mask;
    return;
  }
  byte |= mask;
}
500 
/// Return whether the bit at position `bitPos` of the bit-packed buffer
/// `rawData` is set.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
505 
506 /// Writes value to the bit position `bitPos` in array `rawData`.
507 static void writeBits(char *rawData, size_t bitPos, APInt value) {
508   size_t bitWidth = value.getBitWidth();
509 
510   // If the bitwidth is 1 we just toggle the specific bit.
511   if (bitWidth == 1)
512     return setBit(rawData, bitPos, value.isOneValue());
513 
514   // Otherwise, the bit position is guaranteed to be byte aligned.
515   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
516   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
517               llvm::divideCeil(bitWidth, CHAR_BIT),
518               rawData + (bitPos / CHAR_BIT));
519 }
520 
521 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
522 /// `rawData`.
523 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
524   // Handle a boolean bit position.
525   if (bitWidth == 1)
526     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
527 
528   // Otherwise, the bit position must be 8-bit aligned.
529   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
530   APInt result(bitWidth, 0);
531   std::copy_n(
532       rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
533       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
534   return result;
535 }
536 
537 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
538 /// same element count as 'type'.
539 template <typename Values>
540 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
541   return (values.size() == 1) ||
542          (type.getNumElements() == static_cast<int64_t>(values.size()));
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // DenseElementAttr Iterators
547 //===----------------------------------------------------------------------===//
548 
549 /// Constructs a new iterator.
550 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
551     DenseElementsAttr attr, size_t index)
552     : indexed_accessor_iterator<AttributeElementIterator, const void *,
553                                 Attribute, Attribute, Attribute>(
554           attr.getAsOpaquePointer(), index) {}
555 
556 /// Accesses the Attribute value at this iterator position.
557 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
558   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
559   Type eltTy = owner.getType().getElementType();
560   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
561     if (intEltTy.getWidth() == 1)
562       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
563                            owner.getContext());
564     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
565   }
566   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
567     IntElementIterator intIt(owner, index);
568     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
569     return FloatAttr::get(eltTy, *floatIt);
570   }
571   llvm_unreachable("unexpected element type");
572 }
573 
574 /// Constructs a new iterator.
575 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
576     DenseElementsAttr attr, size_t dataIndex)
577     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
578           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
579 
580 /// Accesses the bool value at this iterator position.
581 bool DenseElementsAttr::BoolElementIterator::operator*() const {
582   return getBit(getData(), getDataIndex());
583 }
584 
585 /// Constructs a new iterator.
586 DenseElementsAttr::IntElementIterator::IntElementIterator(
587     DenseElementsAttr attr, size_t dataIndex)
588     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
589           attr.getRawData().data(), attr.isSplat(), dataIndex),
590       bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
591 
592 /// Accesses the raw APInt value at this iterator position.
593 APInt DenseElementsAttr::IntElementIterator::operator*() const {
594   return readBits(getData(),
595                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
596                   bitWidth);
597 }
598 
599 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
600     const llvm::fltSemantics &smt, IntElementIterator it)
601     : llvm::mapped_iterator<IntElementIterator,
602                             std::function<APFloat(const APInt &)>>(
603           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
604 
605 //===----------------------------------------------------------------------===//
606 // DenseElementsAttr
607 //===----------------------------------------------------------------------===//
608 
609 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
610                                          ArrayRef<Attribute> values) {
611   assert(type.getElementType().isSignlessIntOrFloat() &&
612          "expected int or float element type");
613   assert(hasSameElementsOrSplat(type, values));
614 
615   auto eltType = type.getElementType();
616   size_t bitWidth = getDenseElementBitwidth(eltType);
617   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
618 
619   // Compress the attribute values into a character buffer.
620   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
621                             values.size());
622   APInt intVal;
623   for (unsigned i = 0, e = values.size(); i < e; ++i) {
624     assert(eltType == values[i].getType() &&
625            "expected attribute value to have element type");
626 
627     switch (eltType.getKind()) {
628     case StandardTypes::BF16:
629     case StandardTypes::F16:
630     case StandardTypes::F32:
631     case StandardTypes::F64:
632       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
633       break;
634     case StandardTypes::Integer:
635       intVal = values[i].isa<BoolAttr>()
636                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
637                    : values[i].cast<IntegerAttr>().getValue();
638       break;
639     default:
640       llvm_unreachable("unexpected element type");
641     }
642     assert(intVal.getBitWidth() == bitWidth &&
643            "expected value to have same bitwidth as element type");
644     writeBits(data.data(), i * storageBitWidth, intVal);
645   }
646   return getRaw(type, data, /*isSplat=*/(values.size() == 1));
647 }
648 
649 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
650                                          ArrayRef<bool> values) {
651   assert(hasSameElementsOrSplat(type, values));
652   assert(type.getElementType().isInteger(1));
653 
654   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
655   for (int i = 0, e = values.size(); i != e; ++i)
656     setBit(buff.data(), i, values[i]);
657   return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
658 }
659 
660 /// Constructs a dense integer elements attribute from an array of APInt
661 /// values. Each APInt value is expected to have the same bitwidth as the
662 /// element type of 'type'.
663 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
664                                          ArrayRef<APInt> values) {
665   assert(type.getElementType().isa<IntegerType>());
666   return getRaw(type, values);
667 }
668 
669 // Constructs a dense float elements attribute from an array of APFloat
670 // values. Each APFloat value is expected to have the same bitwidth as the
671 // element type of 'type'.
672 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
673                                          ArrayRef<APFloat> values) {
674   assert(type.getElementType().isa<FloatType>());
675 
676   // Convert the APFloat values to APInt and create a dense elements attribute.
677   std::vector<APInt> intValues(values.size());
678   for (unsigned i = 0, e = values.size(); i != e; ++i)
679     intValues[i] = values[i].bitcastToAPInt();
680   return getRaw(type, intValues);
681 }
682 
683 /// Construct a dense elements attribute from a raw buffer representing the
684 /// data for this attribute. Users should generally not use this methods as
685 /// the expected buffer format may not be a form the user expects.
686 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
687                                                       ArrayRef<char> rawBuffer,
688                                                       bool isSplatBuffer) {
689   return getRaw(type, rawBuffer, isSplatBuffer);
690 }
691 
692 /// Constructs a dense elements attribute from an array of raw APInt values.
693 /// Each APInt value is expected to have the same bitwidth as the element type
694 /// of 'type'.
695 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
696                                             ArrayRef<APInt> values) {
697   assert(hasSameElementsOrSplat(type, values));
698 
699   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
700   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
701   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
702                                 values.size());
703   for (unsigned i = 0, e = values.size(); i != e; ++i) {
704     assert(values[i].getBitWidth() == bitWidth);
705     writeBits(elementData.data(), i * storageBitWidth, values[i]);
706   }
707   return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
708 }
709 
710 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
711                                             ArrayRef<char> data, bool isSplat) {
712   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
713          "type must be ranked tensor or vector");
714   assert(type.hasStaticShape() && "type must have static shape");
715   return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
716                    data, isSplat);
717 }
718 
719 /// Check the information for a C++ data type, check if this type is valid for
720 /// the current attribute. This method is used to verify specific type
721 /// invariants that the templatized 'getValues' method cannot.
722 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
723                               bool isSigned) {
724   // Make sure that the data element size is the same as the type element width.
725   if (getDenseElementBitwidth(type.getElementType()) !=
726       static_cast<size_t>(dataEltSize * CHAR_BIT))
727     return false;
728 
729   // Check that the element type is either float or integer.
730   if (!isInt)
731     return type.getElementType().isa<FloatType>();
732 
733   auto intType = type.getElementType().dyn_cast<IntegerType>();
734   if (!intType)
735     return false;
736 
737   // Make sure signedness semantics is consistent.
738   if (intType.isSignless())
739     return true;
740   return intType.isSigned() ? isSigned : !isSigned;
741 }
742 
743 /// Overload of the 'getRaw' method that asserts that the given type is of
744 /// integer type. This method is used to verify type invariants that the
745 /// templatized 'get' method cannot.
746 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
747                                                       ArrayRef<char> data,
748                                                       int64_t dataEltSize,
749                                                       bool isInt,
750                                                       bool isSigned) {
751   assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
752 
753   int64_t numElements = data.size() / dataEltSize;
754   assert(numElements == 1 || numElements == type.getNumElements());
755   return getRaw(type, data, /*isSplat=*/numElements == 1);
756 }
757 
758 /// A method used to verify specific type invariants that the templatized 'get'
759 /// method cannot.
760 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
761                                           bool isSigned) const {
762   return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned);
763 }
764 
765 /// Returns if this attribute corresponds to a splat, i.e. if all element
766 /// values are the same.
767 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
768 
769 /// Return the held element values as a range of Attributes.
770 auto DenseElementsAttr::getAttributeValues() const
771     -> llvm::iterator_range<AttributeElementIterator> {
772   return {attr_value_begin(), attr_value_end()};
773 }
774 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
775   return AttributeElementIterator(*this, 0);
776 }
777 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
778   return AttributeElementIterator(*this, getNumElements());
779 }
780 
781 /// Return the held element values as a range of bool. The element type of
782 /// this attribute must be of integer type of bitwidth 1.
783 auto DenseElementsAttr::getBoolValues() const
784     -> llvm::iterator_range<BoolElementIterator> {
785   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
786   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
787   (void)eltType;
788   return {BoolElementIterator(*this, 0),
789           BoolElementIterator(*this, getNumElements())};
790 }
791 
792 /// Return the held element values as a range of APInts. The element type of
793 /// this attribute must be of integer type.
794 auto DenseElementsAttr::getIntValues() const
795     -> llvm::iterator_range<IntElementIterator> {
796   assert(getType().getElementType().isa<IntegerType>() &&
797          "expected integer type");
798   return {raw_int_begin(), raw_int_end()};
799 }
800 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
801   assert(getType().getElementType().isa<IntegerType>() &&
802          "expected integer type");
803   return raw_int_begin();
804 }
805 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
806   assert(getType().getElementType().isa<IntegerType>() &&
807          "expected integer type");
808   return raw_int_end();
809 }
810 
811 /// Return the held element values as a range of APFloat. The element type of
812 /// this attribute must be of float type.
813 auto DenseElementsAttr::getFloatValues() const
814     -> llvm::iterator_range<FloatElementIterator> {
815   auto elementType = getType().getElementType().cast<FloatType>();
816   assert(elementType.isa<FloatType>() && "expected float type");
817   const auto &elementSemantics = elementType.getFloatSemantics();
818   return {FloatElementIterator(elementSemantics, raw_int_begin()),
819           FloatElementIterator(elementSemantics, raw_int_end())};
820 }
821 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
822   return getFloatValues().begin();
823 }
824 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
825   return getFloatValues().end();
826 }
827 
828 /// Return the raw storage data held by this attribute.
829 ArrayRef<char> DenseElementsAttr::getRawData() const {
830   return static_cast<ImplType *>(impl)->data;
831 }
832 
833 /// Return a new DenseElementsAttr that has the same data as the current
834 /// attribute, but has been reshaped to 'newType'. The new type must have the
835 /// same total number of elements as well as element type.
836 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
837   ShapedType curType = getType();
838   if (curType == newType)
839     return *this;
840 
841   (void)curType;
842   assert(newType.getElementType() == curType.getElementType() &&
843          "expected the same element type");
844   assert(newType.getNumElements() == curType.getNumElements() &&
845          "expected the same number of elements");
846   return getRaw(newType, getRawData(), isSplat());
847 }
848 
849 DenseElementsAttr
850 DenseElementsAttr::mapValues(Type newElementType,
851                              function_ref<APInt(const APInt &)> mapping) const {
852   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
853 }
854 
855 DenseElementsAttr DenseElementsAttr::mapValues(
856     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
857   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // DenseFPElementsAttr
862 //===----------------------------------------------------------------------===//
863 
864 template <typename Fn, typename Attr>
865 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
866                                 Type newElementType,
867                                 llvm::SmallVectorImpl<char> &data) {
868   size_t bitWidth = getDenseElementBitwidth(newElementType);
869   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
870 
871   ShapedType newArrayType;
872   if (inType.isa<RankedTensorType>())
873     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
874   else if (inType.isa<UnrankedTensorType>())
875     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
876   else if (inType.isa<VectorType>())
877     newArrayType = VectorType::get(inType.getShape(), newElementType);
878   else
879     assert(newArrayType && "Unhandled tensor type");
880 
881   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
882   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
883 
884   // Functor used to process a single element value of the attribute.
885   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
886     auto newInt = mapping(value);
887     assert(newInt.getBitWidth() == bitWidth);
888     writeBits(data.data(), index * storageBitWidth, newInt);
889   };
890 
891   // Check for the splat case.
892   if (attr.isSplat()) {
893     processElt(*attr.begin(), /*index=*/0);
894     return newArrayType;
895   }
896 
897   // Otherwise, process all of the element values.
898   uint64_t elementIdx = 0;
899   for (auto value : attr)
900     processElt(value, elementIdx++);
901   return newArrayType;
902 }
903 
904 DenseElementsAttr DenseFPElementsAttr::mapValues(
905     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
906   llvm::SmallVector<char, 8> elementData;
907   auto newArrayType =
908       mappingHelper(mapping, *this, getType(), newElementType, elementData);
909 
910   return getRaw(newArrayType, elementData, isSplat());
911 }
912 
913 /// Method for supporting type inquiry through isa, cast and dyn_cast.
914 bool DenseFPElementsAttr::classof(Attribute attr) {
915   return attr.isa<DenseElementsAttr>() &&
916          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // DenseIntElementsAttr
921 //===----------------------------------------------------------------------===//
922 
923 DenseElementsAttr DenseIntElementsAttr::mapValues(
924     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
925   llvm::SmallVector<char, 8> elementData;
926   auto newArrayType =
927       mappingHelper(mapping, *this, getType(), newElementType, elementData);
928 
929   return getRaw(newArrayType, elementData, isSplat());
930 }
931 
932 /// Method for supporting type inquiry through isa, cast and dyn_cast.
933 bool DenseIntElementsAttr::classof(Attribute attr) {
934   return attr.isa<DenseElementsAttr>() &&
935          attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
936 }
937 
938 //===----------------------------------------------------------------------===//
939 // OpaqueElementsAttr
940 //===----------------------------------------------------------------------===//
941 
942 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
943                                            StringRef bytes) {
944   assert(TensorType::isValidElementType(type.getElementType()) &&
945          "Input element type should be a valid tensor element type");
946   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
947                    dialect, bytes);
948 }
949 
950 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
951 
952 /// Return the value at the given index. If index does not refer to a valid
953 /// element, then a null attribute is returned.
954 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
955   assert(isValidIndex(index) && "expected valid multi-dimensional index");
956   if (Dialect *dialect = getDialect())
957     return dialect->extractElementHook(*this, index);
958   return Attribute();
959 }
960 
961 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
962 
963 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
964   if (auto *d = getDialect())
965     return d->decodeHook(*this, result);
966   return true;
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // SparseElementsAttr
971 //===----------------------------------------------------------------------===//
972 
973 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
974                                            DenseElementsAttr indices,
975                                            DenseElementsAttr values) {
976   assert(indices.getType().getElementType().isInteger(64) &&
977          "expected sparse indices to be 64-bit integer values");
978   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
979          "type must be ranked tensor or vector");
980   assert(type.hasStaticShape() && "type must have static shape");
981   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
982                    indices.cast<DenseIntElementsAttr>(), values);
983 }
984 
985 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
986   return getImpl()->indices;
987 }
988 
989 DenseElementsAttr SparseElementsAttr::getValues() const {
990   return getImpl()->values;
991 }
992 
993 /// Return the value of the element at the given index.
994 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
995   assert(isValidIndex(index) && "expected valid multi-dimensional index");
996   auto type = getType();
997 
998   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
999   // as a 1-D index array.
1000   auto sparseIndices = getIndices();
1001   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1002 
1003   // Check to see if the indices are a splat.
1004   if (sparseIndices.isSplat()) {
1005     // If the index is also not a splat of the index value, we know that the
1006     // value is zero.
1007     auto splatIndex = *sparseIndexValues.begin();
1008     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1009       return getZeroAttr();
1010 
1011     // If the indices are a splat, we also expect the values to be a splat.
1012     assert(getValues().isSplat() && "expected splat values");
1013     return getValues().getSplatValue();
1014   }
1015 
1016   // Build a mapping between known indices and the offset of the stored element.
1017   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1018   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1019   size_t rank = type.getRank();
1020   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1021     mappedIndices.try_emplace(
1022         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1023 
1024   // Look for the provided index key within the mapped indices. If the provided
1025   // index is not found, then return a zero attribute.
1026   auto it = mappedIndices.find(index);
1027   if (it == mappedIndices.end())
1028     return getZeroAttr();
1029 
1030   // Otherwise, return the held sparse value element.
1031   return getValues().getValue(it->second);
1032 }
1033 
1034 /// Get a zero APFloat for the given sparse attribute.
1035 APFloat SparseElementsAttr::getZeroAPFloat() const {
1036   auto eltType = getType().getElementType().cast<FloatType>();
1037   return APFloat(eltType.getFloatSemantics());
1038 }
1039 
1040 /// Get a zero APInt for the given sparse attribute.
1041 APInt SparseElementsAttr::getZeroAPInt() const {
1042   auto eltType = getType().getElementType().cast<IntegerType>();
1043   return APInt::getNullValue(eltType.getWidth());
1044 }
1045 
1046 /// Get a zero attribute for the given attribute type.
1047 Attribute SparseElementsAttr::getZeroAttr() const {
1048   auto eltType = getType().getElementType();
1049 
1050   // Handle floating point elements.
1051   if (eltType.isa<FloatType>())
1052     return FloatAttr::get(eltType, 0);
1053 
1054   // Otherwise, this is an integer.
1055   auto intEltTy = eltType.cast<IntegerType>();
1056   if (intEltTy.getWidth() == 1)
1057     return BoolAttr::get(false, eltType.getContext());
1058   return IntegerAttr::get(eltType, 0);
1059 }
1060 
1061 /// Flatten, and return, all of the sparse indices in this attribute in
1062 /// row-major order.
1063 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1064   std::vector<ptrdiff_t> flatSparseIndices;
1065 
1066   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1067   // as a 1-D index array.
1068   auto sparseIndices = getIndices();
1069   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1070   if (sparseIndices.isSplat()) {
1071     SmallVector<uint64_t, 8> indices(getType().getRank(),
1072                                      *sparseIndexValues.begin());
1073     flatSparseIndices.push_back(getFlattenedIndex(indices));
1074     return flatSparseIndices;
1075   }
1076 
1077   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1078   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1079   size_t rank = getType().getRank();
1080   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1081     flatSparseIndices.push_back(getFlattenedIndex(
1082         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1083   return flatSparseIndices;
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // NamedAttributeList
1088 //===----------------------------------------------------------------------===//
1089 
1090 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
1091   setAttrs(attributes);
1092 }
1093 
1094 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
1095   return attrs ? attrs.getValue() : llvm::None;
1096 }
1097 
1098 /// Replace the held attributes with ones provided in 'newAttrs'.
1099 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
1100   // Don't create an attribute list if there are no attributes.
1101   if (attributes.empty())
1102     attrs = nullptr;
1103   else
1104     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1105 }
1106 
1107 /// Return the specified attribute if present, null otherwise.
1108 Attribute NamedAttributeList::get(StringRef name) const {
1109   return attrs ? attrs.get(name) : nullptr;
1110 }
1111 
1112 /// Return the specified attribute if present, null otherwise.
1113 Attribute NamedAttributeList::get(Identifier name) const {
1114   return attrs ? attrs.get(name) : nullptr;
1115 }
1116 
1117 /// If the an attribute exists with the specified name, change it to the new
1118 /// value.  Otherwise, add a new attribute with the specified name/value.
1119 void NamedAttributeList::set(Identifier name, Attribute value) {
1120   assert(value && "attributes may never be null");
1121 
1122   // If we already have this attribute, replace it.
1123   auto origAttrs = getAttrs();
1124   SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
1125   for (auto &elt : newAttrs)
1126     if (elt.first == name) {
1127       elt.second = value;
1128       attrs = DictionaryAttr::get(newAttrs, value.getContext());
1129       return;
1130     }
1131 
1132   // Otherwise, add it.
1133   newAttrs.push_back({name, value});
1134   attrs = DictionaryAttr::get(newAttrs, value.getContext());
1135 }
1136 
1137 /// Remove the attribute with the specified name if it exists.  The return
1138 /// value indicates whether the attribute was present or not.
1139 auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
1140   auto origAttrs = getAttrs();
1141   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1142     if (origAttrs[i].first == name) {
1143       // Handle the simple case of removing the only attribute in the list.
1144       if (e == 1) {
1145         attrs = nullptr;
1146         return RemoveResult::Removed;
1147       }
1148 
1149       SmallVector<NamedAttribute, 8> newAttrs;
1150       newAttrs.reserve(origAttrs.size() - 1);
1151       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1152       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1153       attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
1154       return RemoveResult::Removed;
1155     }
1156   }
1157   return RemoveResult::NotFound;
1158 }
1159