1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 // AttributeStorage
25 //===----------------------------------------------------------------------===//
26 
27 AttributeStorage::AttributeStorage(Type type)
28     : type(type.getAsOpaquePointer()) {}
29 AttributeStorage::AttributeStorage() : type(nullptr) {}
30 
31 Type AttributeStorage::getType() const {
32   return Type::getFromOpaquePointer(type);
33 }
34 void AttributeStorage::setType(Type newType) {
35   type = newType.getAsOpaquePointer();
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Attribute
40 //===----------------------------------------------------------------------===//
41 
42 /// Return the type of this attribute.
43 Type Attribute::getType() const { return impl->getType(); }
44 
45 /// Return the context this attribute belongs to.
46 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
47 
48 /// Get the dialect this attribute is registered to.
49 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
50 
51 //===----------------------------------------------------------------------===//
52 // AffineMapAttr
53 //===----------------------------------------------------------------------===//
54 
55 AffineMapAttr AffineMapAttr::get(AffineMap value) {
56   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
57 }
58 
59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
60 
61 //===----------------------------------------------------------------------===//
62 // ArrayAttr
63 //===----------------------------------------------------------------------===//
64 
65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
66   return Base::get(context, StandardAttributes::Array, value);
67 }
68 
69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
70 
71 Attribute ArrayAttr::operator[](unsigned idx) const {
72   assert(idx < size() && "index out of bounds");
73   return getValue()[idx];
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // BoolAttr
78 //===----------------------------------------------------------------------===//
79 
80 bool BoolAttr::getValue() const { return getImpl()->value; }
81 
82 //===----------------------------------------------------------------------===//
83 // DictionaryAttr
84 //===----------------------------------------------------------------------===//
85 
86 /// Perform a three-way comparison between the names of the specified
87 /// NamedAttributes.
88 static int compareNamedAttributes(const NamedAttribute *lhs,
89                                   const NamedAttribute *rhs) {
90   return lhs->first.strref().compare(rhs->first.strref());
91 }
92 
93 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
94                                    MLIRContext *context) {
95   assert(llvm::all_of(value,
96                       [](const NamedAttribute &attr) { return attr.second; }) &&
97          "value cannot have null entries");
98 
99   // We need to sort the element list to canonicalize it, but we also don't want
100   // to do a ton of work in the super common case where the element list is
101   // already sorted.
102   SmallVector<NamedAttribute, 8> storage;
103   switch (value.size()) {
104   case 0:
105     break;
106   case 1:
107     // A single element is already sorted.
108     break;
109   case 2:
110     assert(value[0].first != value[1].first &&
111            "DictionaryAttr element names must be unique");
112 
113     // Don't invoke a general sort for two element case.
114     if (value[0].first.strref() > value[1].first.strref()) {
115       storage.push_back(value[1]);
116       storage.push_back(value[0]);
117       value = storage;
118     }
119     break;
120   default:
121     // Check to see they are sorted already.
122     bool isSorted = true;
123     for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
124       if (value[i].first.strref() > value[i + 1].first.strref()) {
125         isSorted = false;
126         break;
127       }
128     }
129     // If not, do a general sort.
130     if (!isSorted) {
131       storage.append(value.begin(), value.end());
132       llvm::array_pod_sort(storage.begin(), storage.end(),
133                            compareNamedAttributes);
134       value = storage;
135     }
136 
137     // Ensure that the attribute elements are unique.
138     assert(std::adjacent_find(value.begin(), value.end(),
139                               [](NamedAttribute l, NamedAttribute r) {
140                                 return l.first == r.first;
141                               }) == value.end() &&
142            "DictionaryAttr element names must be unique");
143   }
144 
145   return Base::get(context, StandardAttributes::Dictionary, value);
146 }
147 
148 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
149   return getImpl()->getElements();
150 }
151 
152 /// Return the specified attribute if present, null otherwise.
153 Attribute DictionaryAttr::get(StringRef name) const {
154   ArrayRef<NamedAttribute> values = getValue();
155   auto compare = [](NamedAttribute attr, StringRef name) {
156     return attr.first.strref() < name;
157   };
158   auto it = llvm::lower_bound(values, name, compare);
159   return it != values.end() && it->first.is(name) ? it->second : Attribute();
160 }
161 Attribute DictionaryAttr::get(Identifier name) const {
162   for (auto elt : getValue())
163     if (elt.first == name)
164       return elt.second;
165   return nullptr;
166 }
167 
168 DictionaryAttr::iterator DictionaryAttr::begin() const {
169   return getValue().begin();
170 }
171 DictionaryAttr::iterator DictionaryAttr::end() const {
172   return getValue().end();
173 }
174 size_t DictionaryAttr::size() const { return getValue().size(); }
175 
176 //===----------------------------------------------------------------------===//
177 // FloatAttr
178 //===----------------------------------------------------------------------===//
179 
180 FloatAttr FloatAttr::get(Type type, double value) {
181   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
182 }
183 
184 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
185   return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
186                           type, value);
187 }
188 
189 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
190   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
191 }
192 
193 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
194   return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
195                           type, value);
196 }
197 
198 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
199 
200 double FloatAttr::getValueAsDouble() const {
201   return getValueAsDouble(getValue());
202 }
203 double FloatAttr::getValueAsDouble(APFloat value) {
204   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
205     bool losesInfo = false;
206     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
207                   &losesInfo);
208   }
209   return value.convertToDouble();
210 }
211 
212 /// Verify construction invariants.
213 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
214                                                Type type) {
215   if (!type.isa<FloatType>())
216     return emitOptionalError(loc, "expected floating point type");
217   return success();
218 }
219 
220 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
221                                                       MLIRContext *ctx,
222                                                       Type type, double value) {
223   return verifyFloatTypeInvariants(loc, type);
224 }
225 
226 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
227                                                       MLIRContext *ctx,
228                                                       Type type,
229                                                       const APFloat &value) {
230   // Verify that the type is correct.
231   if (failed(verifyFloatTypeInvariants(loc, type)))
232     return failure();
233 
234   // Verify that the type semantics match that of the value.
235   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
236     return emitOptionalError(
237         loc, "FloatAttr type doesn't match the type implied by its value");
238   }
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // SymbolRefAttr
244 //===----------------------------------------------------------------------===//
245 
246 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
247   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
248       .cast<FlatSymbolRefAttr>();
249 }
250 
251 SymbolRefAttr SymbolRefAttr::get(StringRef value,
252                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
253                                  MLIRContext *ctx) {
254   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
255 }
256 
257 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
258 
259 StringRef SymbolRefAttr::getLeafReference() const {
260   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
261   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
262 }
263 
264 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
265   return getImpl()->getNestedRefs();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // IntegerAttr
270 //===----------------------------------------------------------------------===//
271 
272 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
273   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
274 }
275 
276 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
277   // This uses 64 bit APInts by default for index type.
278   if (type.isIndex())
279     return get(type, APInt(64, value));
280 
281   auto intType = type.cast<IntegerType>();
282   return get(type, APInt(intType.getWidth(), value));
283 }
284 
285 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
286 
287 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
288 
289 static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc,
290                                                  Type type) {
291   if (type.isa<IntegerType>() || type.isa<IndexType>())
292     return success();
293   return emitOptionalError(loc, "expected integer or index type");
294 }
295 
296 LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
297                                                         MLIRContext *ctx,
298                                                         Type type,
299                                                         int64_t value) {
300   return verifyIntegerTypeInvariants(loc, type);
301 }
302 
303 LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
304                                                         MLIRContext *ctx,
305                                                         Type type,
306                                                         const APInt &value) {
307   if (failed(verifyIntegerTypeInvariants(loc, type)))
308     return failure();
309   if (auto integerType = type.dyn_cast<IntegerType>())
310     if (integerType.getWidth() != value.getBitWidth())
311       return emitOptionalError(
312           loc, "integer type bit width (", integerType.getWidth(),
313           ") doesn't match value bit width (", value.getBitWidth(), ")");
314   return success();
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // IntegerSetAttr
319 //===----------------------------------------------------------------------===//
320 
321 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
322   return Base::get(value.getConstraint(0).getContext(),
323                    StandardAttributes::IntegerSet, value);
324 }
325 
326 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
327 
328 //===----------------------------------------------------------------------===//
329 // OpaqueAttr
330 //===----------------------------------------------------------------------===//
331 
332 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
333                            MLIRContext *context) {
334   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
335                    type);
336 }
337 
338 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
339                                   Type type, Location location) {
340   return Base::getChecked(location, type.getContext(),
341                           StandardAttributes::Opaque, dialect, attrData, type);
342 }
343 
344 /// Returns the dialect namespace of the opaque attribute.
345 Identifier OpaqueAttr::getDialectNamespace() const {
346   return getImpl()->dialectNamespace;
347 }
348 
349 /// Returns the raw attribute data of the opaque attribute.
350 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
351 
352 /// Verify the construction of an opaque attribute.
353 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
354                                                        MLIRContext *context,
355                                                        Identifier dialect,
356                                                        StringRef attrData,
357                                                        Type type) {
358   if (!Dialect::isValidNamespace(dialect.strref()))
359     return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
360   return success();
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // StringAttr
365 //===----------------------------------------------------------------------===//
366 
367 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
368   return get(bytes, NoneType::get(context));
369 }
370 
371 /// Get an instance of a StringAttr with the given string and Type.
372 StringAttr StringAttr::get(StringRef bytes, Type type) {
373   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
374 }
375 
376 StringRef StringAttr::getValue() const { return getImpl()->value; }
377 
378 //===----------------------------------------------------------------------===//
379 // TypeAttr
380 //===----------------------------------------------------------------------===//
381 
382 TypeAttr TypeAttr::get(Type value) {
383   return Base::get(value.getContext(), StandardAttributes::Type, value);
384 }
385 
386 Type TypeAttr::getValue() const { return getImpl()->value; }
387 
388 //===----------------------------------------------------------------------===//
389 // ElementsAttr
390 //===----------------------------------------------------------------------===//
391 
392 ShapedType ElementsAttr::getType() const {
393   return Attribute::getType().cast<ShapedType>();
394 }
395 
396 /// Returns the number of elements held by this attribute.
397 int64_t ElementsAttr::getNumElements() const {
398   return getType().getNumElements();
399 }
400 
401 /// Return the value at the given index. If index does not refer to a valid
402 /// element, then a null attribute is returned.
403 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
404   switch (getKind()) {
405   case StandardAttributes::DenseElements:
406     return cast<DenseElementsAttr>().getValue(index);
407   case StandardAttributes::OpaqueElements:
408     return cast<OpaqueElementsAttr>().getValue(index);
409   case StandardAttributes::SparseElements:
410     return cast<SparseElementsAttr>().getValue(index);
411   default:
412     llvm_unreachable("unknown ElementsAttr kind");
413   }
414 }
415 
416 /// Return if the given 'index' refers to a valid element in this attribute.
417 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
418   auto type = getType();
419 
420   // Verify that the rank of the indices matches the held type.
421   auto rank = type.getRank();
422   if (rank != static_cast<int64_t>(index.size()))
423     return false;
424 
425   // Verify that all of the indices are within the shape dimensions.
426   auto shape = type.getShape();
427   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
428     return static_cast<int64_t>(index[i]) < shape[i];
429   });
430 }
431 
432 ElementsAttr
433 ElementsAttr::mapValues(Type newElementType,
434                         function_ref<APInt(const APInt &)> mapping) const {
435   switch (getKind()) {
436   case StandardAttributes::DenseElements:
437     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
438   default:
439     llvm_unreachable("unsupported ElementsAttr subtype");
440   }
441 }
442 
443 ElementsAttr
444 ElementsAttr::mapValues(Type newElementType,
445                         function_ref<APInt(const APFloat &)> mapping) const {
446   switch (getKind()) {
447   case StandardAttributes::DenseElements:
448     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
449   default:
450     llvm_unreachable("unsupported ElementsAttr subtype");
451   }
452 }
453 
454 /// Returns the 1 dimensional flattened row-major index from the given
455 /// multi-dimensional index.
456 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
457   assert(isValidIndex(index) && "expected valid multi-dimensional index");
458   auto type = getType();
459 
460   // Reduce the provided multidimensional index into a flattended 1D row-major
461   // index.
462   auto rank = type.getRank();
463   auto shape = type.getShape();
464   uint64_t valueIndex = 0;
465   uint64_t dimMultiplier = 1;
466   for (int i = rank - 1; i >= 0; --i) {
467     valueIndex += index[i] * dimMultiplier;
468     dimMultiplier *= shape[i];
469   }
470   return valueIndex;
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // DenseElementAttr Utilities
475 //===----------------------------------------------------------------------===//
476 
477 static size_t getDenseElementBitwidth(Type eltType) {
478   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
479   // with double semantics.
480   return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
481 }
482 
483 /// Get the bitwidth of a dense element type within the buffer.
484 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
485 static size_t getDenseElementStorageWidth(size_t origWidth) {
486   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
487 }
488 
489 /// Set a bit to a specific value.
490 static void setBit(char *rawData, size_t bitPos, bool value) {
491   if (value)
492     rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
493   else
494     rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
495 }
496 
497 /// Return the value of the specified bit.
498 static bool getBit(const char *rawData, size_t bitPos) {
499   return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
500 }
501 
502 /// Writes value to the bit position `bitPos` in array `rawData`.
503 static void writeBits(char *rawData, size_t bitPos, APInt value) {
504   size_t bitWidth = value.getBitWidth();
505 
506   // If the bitwidth is 1 we just toggle the specific bit.
507   if (bitWidth == 1)
508     return setBit(rawData, bitPos, value.isOneValue());
509 
510   // Otherwise, the bit position is guaranteed to be byte aligned.
511   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
512   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
513               llvm::divideCeil(bitWidth, CHAR_BIT),
514               rawData + (bitPos / CHAR_BIT));
515 }
516 
517 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
518 /// `rawData`.
519 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
520   // Handle a boolean bit position.
521   if (bitWidth == 1)
522     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
523 
524   // Otherwise, the bit position must be 8-bit aligned.
525   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
526   APInt result(bitWidth, 0);
527   std::copy_n(
528       rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
529       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
530   return result;
531 }
532 
533 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
534 /// same element count as 'type'.
535 template <typename Values>
536 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
537   return (values.size() == 1) ||
538          (type.getNumElements() == static_cast<int64_t>(values.size()));
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // DenseElementAttr Iterators
543 //===----------------------------------------------------------------------===//
544 
545 /// Constructs a new iterator.
546 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
547     DenseElementsAttr attr, size_t index)
548     : indexed_accessor_iterator<AttributeElementIterator, const void *,
549                                 Attribute, Attribute, Attribute>(
550           attr.getAsOpaquePointer(), index) {}
551 
552 /// Accesses the Attribute value at this iterator position.
553 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
554   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
555   Type eltTy = owner.getType().getElementType();
556   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
557     if (intEltTy.getWidth() == 1)
558       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
559                            owner.getContext());
560     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
561   }
562   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
563     IntElementIterator intIt(owner, index);
564     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
565     return FloatAttr::get(eltTy, *floatIt);
566   }
567   llvm_unreachable("unexpected element type");
568 }
569 
570 /// Constructs a new iterator.
571 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
572     DenseElementsAttr attr, size_t dataIndex)
573     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
574           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
575 
576 /// Accesses the bool value at this iterator position.
577 bool DenseElementsAttr::BoolElementIterator::operator*() const {
578   return getBit(getData(), getDataIndex());
579 }
580 
581 /// Constructs a new iterator.
582 DenseElementsAttr::IntElementIterator::IntElementIterator(
583     DenseElementsAttr attr, size_t dataIndex)
584     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
585           attr.getRawData().data(), attr.isSplat(), dataIndex),
586       bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
587 
588 /// Accesses the raw APInt value at this iterator position.
589 APInt DenseElementsAttr::IntElementIterator::operator*() const {
590   return readBits(getData(),
591                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
592                   bitWidth);
593 }
594 
595 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
596     const llvm::fltSemantics &smt, IntElementIterator it)
597     : llvm::mapped_iterator<IntElementIterator,
598                             std::function<APFloat(const APInt &)>>(
599           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
600 
601 //===----------------------------------------------------------------------===//
602 // DenseElementsAttr
603 //===----------------------------------------------------------------------===//
604 
605 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
606                                          ArrayRef<Attribute> values) {
607   assert(type.getElementType().isIntOrFloat() &&
608          "expected int or float element type");
609   assert(hasSameElementsOrSplat(type, values));
610 
611   auto eltType = type.getElementType();
612   size_t bitWidth = getDenseElementBitwidth(eltType);
613   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
614 
615   // Compress the attribute values into a character buffer.
616   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
617                             values.size());
618   APInt intVal;
619   for (unsigned i = 0, e = values.size(); i < e; ++i) {
620     assert(eltType == values[i].getType() &&
621            "expected attribute value to have element type");
622 
623     switch (eltType.getKind()) {
624     case StandardTypes::BF16:
625     case StandardTypes::F16:
626     case StandardTypes::F32:
627     case StandardTypes::F64:
628       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
629       break;
630     case StandardTypes::Integer:
631       intVal = values[i].isa<BoolAttr>()
632                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
633                    : values[i].cast<IntegerAttr>().getValue();
634       break;
635     default:
636       llvm_unreachable("unexpected element type");
637     }
638     assert(intVal.getBitWidth() == bitWidth &&
639            "expected value to have same bitwidth as element type");
640     writeBits(data.data(), i * storageBitWidth, intVal);
641   }
642   return getRaw(type, data, /*isSplat=*/(values.size() == 1));
643 }
644 
645 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
646                                          ArrayRef<bool> values) {
647   assert(hasSameElementsOrSplat(type, values));
648   assert(type.getElementType().isInteger(1));
649 
650   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
651   for (int i = 0, e = values.size(); i != e; ++i)
652     setBit(buff.data(), i, values[i]);
653   return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
654 }
655 
656 /// Constructs a dense integer elements attribute from an array of APInt
657 /// values. Each APInt value is expected to have the same bitwidth as the
658 /// element type of 'type'.
659 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
660                                          ArrayRef<APInt> values) {
661   assert(type.getElementType().isa<IntegerType>());
662   return getRaw(type, values);
663 }
664 
665 // Constructs a dense float elements attribute from an array of APFloat
666 // values. Each APFloat value is expected to have the same bitwidth as the
667 // element type of 'type'.
668 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
669                                          ArrayRef<APFloat> values) {
670   assert(type.getElementType().isa<FloatType>());
671 
672   // Convert the APFloat values to APInt and create a dense elements attribute.
673   std::vector<APInt> intValues(values.size());
674   for (unsigned i = 0, e = values.size(); i != e; ++i)
675     intValues[i] = values[i].bitcastToAPInt();
676   return getRaw(type, intValues);
677 }
678 
679 // Constructs a dense elements attribute from an array of raw APInt values.
680 // Each APInt value is expected to have the same bitwidth as the element type
681 // of 'type'.
682 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
683                                             ArrayRef<APInt> values) {
684   assert(hasSameElementsOrSplat(type, values));
685 
686   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
687   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
688   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
689                                 values.size());
690   for (unsigned i = 0, e = values.size(); i != e; ++i) {
691     assert(values[i].getBitWidth() == bitWidth);
692     writeBits(elementData.data(), i * storageBitWidth, values[i]);
693   }
694   return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
695 }
696 
697 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
698                                             ArrayRef<char> data, bool isSplat) {
699   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
700          "type must be ranked tensor or vector");
701   assert(type.hasStaticShape() && "type must have static shape");
702   return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
703                    data, isSplat);
704 }
705 
706 /// Check the information for a c++ data type, check if this type is valid for
707 /// the current attribute. This method is used to verify specific type
708 /// invariants that the templatized 'getValues' method cannot.
709 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
710                               bool isInt) {
711   // Make sure that the data element size is the same as the type element width.
712   if (getDenseElementBitwidth(type.getElementType()) !=
713       static_cast<size_t>(dataEltSize * CHAR_BIT))
714     return false;
715 
716   // Check that the element type is valid.
717   return isInt ? type.getElementType().isa<IntegerType>()
718                : type.getElementType().isa<FloatType>();
719 }
720 
721 /// Overload of the 'getRaw' method that asserts that the given type is of
722 /// integer type. This method is used to verify type invariants that the
723 /// templatized 'get' method cannot.
724 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
725                                                       ArrayRef<char> data,
726                                                       int64_t dataEltSize,
727                                                       bool isInt) {
728   assert(::isValidIntOrFloat(type, dataEltSize, isInt));
729 
730   int64_t numElements = data.size() / dataEltSize;
731   assert(numElements == 1 || numElements == type.getNumElements());
732   return getRaw(type, data, /*isSplat=*/numElements == 1);
733 }
734 
735 /// A method used to verify specific type invariants that the templatized 'get'
736 /// method cannot.
737 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
738                                           bool isInt) const {
739   return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
740 }
741 
742 /// Return the raw storage data held by this attribute.
743 ArrayRef<char> DenseElementsAttr::getRawData() const {
744   return static_cast<ImplType *>(impl)->data;
745 }
746 
747 /// Returns if this attribute corresponds to a splat, i.e. if all element
748 /// values are the same.
749 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
750 
751 /// Return the held element values as a range of Attributes.
752 auto DenseElementsAttr::getAttributeValues() const
753     -> llvm::iterator_range<AttributeElementIterator> {
754   return {attr_value_begin(), attr_value_end()};
755 }
756 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
757   return AttributeElementIterator(*this, 0);
758 }
759 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
760   return AttributeElementIterator(*this, getNumElements());
761 }
762 
763 /// Return the held element values as a range of bool. The element type of
764 /// this attribute must be of integer type of bitwidth 1.
765 auto DenseElementsAttr::getBoolValues() const
766     -> llvm::iterator_range<BoolElementIterator> {
767   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
768   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
769   (void)eltType;
770   return {BoolElementIterator(*this, 0),
771           BoolElementIterator(*this, getNumElements())};
772 }
773 
774 /// Return the held element values as a range of APInts. The element type of
775 /// this attribute must be of integer type.
776 auto DenseElementsAttr::getIntValues() const
777     -> llvm::iterator_range<IntElementIterator> {
778   assert(getType().getElementType().isa<IntegerType>() &&
779          "expected integer type");
780   return {raw_int_begin(), raw_int_end()};
781 }
782 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
783   assert(getType().getElementType().isa<IntegerType>() &&
784          "expected integer type");
785   return raw_int_begin();
786 }
787 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
788   assert(getType().getElementType().isa<IntegerType>() &&
789          "expected integer type");
790   return raw_int_end();
791 }
792 
793 /// Return the held element values as a range of APFloat. The element type of
794 /// this attribute must be of float type.
795 auto DenseElementsAttr::getFloatValues() const
796     -> llvm::iterator_range<FloatElementIterator> {
797   auto elementType = getType().getElementType().cast<FloatType>();
798   assert(elementType.isa<FloatType>() && "expected float type");
799   const auto &elementSemantics = elementType.getFloatSemantics();
800   return {FloatElementIterator(elementSemantics, raw_int_begin()),
801           FloatElementIterator(elementSemantics, raw_int_end())};
802 }
803 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
804   return getFloatValues().begin();
805 }
806 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
807   return getFloatValues().end();
808 }
809 
810 /// Return a new DenseElementsAttr that has the same data as the current
811 /// attribute, but has been reshaped to 'newType'. The new type must have the
812 /// same total number of elements as well as element type.
813 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
814   ShapedType curType = getType();
815   if (curType == newType)
816     return *this;
817 
818   (void)curType;
819   assert(newType.getElementType() == curType.getElementType() &&
820          "expected the same element type");
821   assert(newType.getNumElements() == curType.getNumElements() &&
822          "expected the same number of elements");
823   return getRaw(newType, getRawData(), isSplat());
824 }
825 
826 DenseElementsAttr
827 DenseElementsAttr::mapValues(Type newElementType,
828                              function_ref<APInt(const APInt &)> mapping) const {
829   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
830 }
831 
832 DenseElementsAttr DenseElementsAttr::mapValues(
833     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
834   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
835 }
836 
837 //===----------------------------------------------------------------------===//
838 // DenseFPElementsAttr
839 //===----------------------------------------------------------------------===//
840 
841 template <typename Fn, typename Attr>
842 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
843                                 Type newElementType,
844                                 llvm::SmallVectorImpl<char> &data) {
845   size_t bitWidth = getDenseElementBitwidth(newElementType);
846   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
847 
848   ShapedType newArrayType;
849   if (inType.isa<RankedTensorType>())
850     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
851   else if (inType.isa<UnrankedTensorType>())
852     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
853   else if (inType.isa<VectorType>())
854     newArrayType = VectorType::get(inType.getShape(), newElementType);
855   else
856     assert(newArrayType && "Unhandled tensor type");
857 
858   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
859   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
860 
861   // Functor used to process a single element value of the attribute.
862   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
863     auto newInt = mapping(value);
864     assert(newInt.getBitWidth() == bitWidth);
865     writeBits(data.data(), index * storageBitWidth, newInt);
866   };
867 
868   // Check for the splat case.
869   if (attr.isSplat()) {
870     processElt(*attr.begin(), /*index=*/0);
871     return newArrayType;
872   }
873 
874   // Otherwise, process all of the element values.
875   uint64_t elementIdx = 0;
876   for (auto value : attr)
877     processElt(value, elementIdx++);
878   return newArrayType;
879 }
880 
881 DenseElementsAttr DenseFPElementsAttr::mapValues(
882     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
883   llvm::SmallVector<char, 8> elementData;
884   auto newArrayType =
885       mappingHelper(mapping, *this, getType(), newElementType, elementData);
886 
887   return getRaw(newArrayType, elementData, isSplat());
888 }
889 
890 /// Method for supporting type inquiry through isa, cast and dyn_cast.
891 bool DenseFPElementsAttr::classof(Attribute attr) {
892   return attr.isa<DenseElementsAttr>() &&
893          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
894 }
895 
896 //===----------------------------------------------------------------------===//
897 // DenseIntElementsAttr
898 //===----------------------------------------------------------------------===//
899 
900 DenseElementsAttr DenseIntElementsAttr::mapValues(
901     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
902   llvm::SmallVector<char, 8> elementData;
903   auto newArrayType =
904       mappingHelper(mapping, *this, getType(), newElementType, elementData);
905 
906   return getRaw(newArrayType, elementData, isSplat());
907 }
908 
909 /// Method for supporting type inquiry through isa, cast and dyn_cast.
910 bool DenseIntElementsAttr::classof(Attribute attr) {
911   return attr.isa<DenseElementsAttr>() &&
912          attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // OpaqueElementsAttr
917 //===----------------------------------------------------------------------===//
918 
919 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
920                                            StringRef bytes) {
921   assert(TensorType::isValidElementType(type.getElementType()) &&
922          "Input element type should be a valid tensor element type");
923   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
924                    dialect, bytes);
925 }
926 
927 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
928 
929 /// Return the value at the given index. If index does not refer to a valid
930 /// element, then a null attribute is returned.
931 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
932   assert(isValidIndex(index) && "expected valid multi-dimensional index");
933   if (Dialect *dialect = getDialect())
934     return dialect->extractElementHook(*this, index);
935   return Attribute();
936 }
937 
938 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
939 
940 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
941   if (auto *d = getDialect())
942     return d->decodeHook(*this, result);
943   return true;
944 }
945 
946 //===----------------------------------------------------------------------===//
947 // SparseElementsAttr
948 //===----------------------------------------------------------------------===//
949 
950 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
951                                            DenseElementsAttr indices,
952                                            DenseElementsAttr values) {
953   assert(indices.getType().getElementType().isInteger(64) &&
954          "expected sparse indices to be 64-bit integer values");
955   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
956          "type must be ranked tensor or vector");
957   assert(type.hasStaticShape() && "type must have static shape");
958   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
959                    indices.cast<DenseIntElementsAttr>(), values);
960 }
961 
962 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
963   return getImpl()->indices;
964 }
965 
966 DenseElementsAttr SparseElementsAttr::getValues() const {
967   return getImpl()->values;
968 }
969 
970 /// Return the value of the element at the given index.
971 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
972   assert(isValidIndex(index) && "expected valid multi-dimensional index");
973   auto type = getType();
974 
975   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
976   // as a 1-D index array.
977   auto sparseIndices = getIndices();
978   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
979 
980   // Check to see if the indices are a splat.
981   if (sparseIndices.isSplat()) {
982     // If the index is also not a splat of the index value, we know that the
983     // value is zero.
984     auto splatIndex = *sparseIndexValues.begin();
985     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
986       return getZeroAttr();
987 
988     // If the indices are a splat, we also expect the values to be a splat.
989     assert(getValues().isSplat() && "expected splat values");
990     return getValues().getSplatValue();
991   }
992 
993   // Build a mapping between known indices and the offset of the stored element.
994   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
995   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
996   size_t rank = type.getRank();
997   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
998     mappedIndices.try_emplace(
999         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1000 
1001   // Look for the provided index key within the mapped indices. If the provided
1002   // index is not found, then return a zero attribute.
1003   auto it = mappedIndices.find(index);
1004   if (it == mappedIndices.end())
1005     return getZeroAttr();
1006 
1007   // Otherwise, return the held sparse value element.
1008   return getValues().getValue(it->second);
1009 }
1010 
1011 /// Get a zero APFloat for the given sparse attribute.
1012 APFloat SparseElementsAttr::getZeroAPFloat() const {
1013   auto eltType = getType().getElementType().cast<FloatType>();
1014   return APFloat(eltType.getFloatSemantics());
1015 }
1016 
1017 /// Get a zero APInt for the given sparse attribute.
1018 APInt SparseElementsAttr::getZeroAPInt() const {
1019   auto eltType = getType().getElementType().cast<IntegerType>();
1020   return APInt::getNullValue(eltType.getWidth());
1021 }
1022 
1023 /// Get a zero attribute for the given attribute type.
1024 Attribute SparseElementsAttr::getZeroAttr() const {
1025   auto eltType = getType().getElementType();
1026 
1027   // Handle floating point elements.
1028   if (eltType.isa<FloatType>())
1029     return FloatAttr::get(eltType, 0);
1030 
1031   // Otherwise, this is an integer.
1032   auto intEltTy = eltType.cast<IntegerType>();
1033   if (intEltTy.getWidth() == 1)
1034     return BoolAttr::get(false, eltType.getContext());
1035   return IntegerAttr::get(eltType, 0);
1036 }
1037 
1038 /// Flatten, and return, all of the sparse indices in this attribute in
1039 /// row-major order.
1040 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1041   std::vector<ptrdiff_t> flatSparseIndices;
1042 
1043   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1044   // as a 1-D index array.
1045   auto sparseIndices = getIndices();
1046   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1047   if (sparseIndices.isSplat()) {
1048     SmallVector<uint64_t, 8> indices(getType().getRank(),
1049                                      *sparseIndexValues.begin());
1050     flatSparseIndices.push_back(getFlattenedIndex(indices));
1051     return flatSparseIndices;
1052   }
1053 
1054   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1055   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1056   size_t rank = getType().getRank();
1057   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1058     flatSparseIndices.push_back(getFlattenedIndex(
1059         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1060   return flatSparseIndices;
1061 }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // NamedAttributeList
1065 //===----------------------------------------------------------------------===//
1066 
1067 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
1068   setAttrs(attributes);
1069 }
1070 
1071 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
1072   return attrs ? attrs.getValue() : llvm::None;
1073 }
1074 
1075 /// Replace the held attributes with ones provided in 'newAttrs'.
1076 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
1077   // Don't create an attribute list if there are no attributes.
1078   if (attributes.empty())
1079     attrs = nullptr;
1080   else
1081     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1082 }
1083 
1084 /// Return the specified attribute if present, null otherwise.
1085 Attribute NamedAttributeList::get(StringRef name) const {
1086   return attrs ? attrs.get(name) : nullptr;
1087 }
1088 
1089 /// Return the specified attribute if present, null otherwise.
1090 Attribute NamedAttributeList::get(Identifier name) const {
1091   return attrs ? attrs.get(name) : nullptr;
1092 }
1093 
1094 /// If the an attribute exists with the specified name, change it to the new
1095 /// value.  Otherwise, add a new attribute with the specified name/value.
1096 void NamedAttributeList::set(Identifier name, Attribute value) {
1097   assert(value && "attributes may never be null");
1098 
1099   // If we already have this attribute, replace it.
1100   auto origAttrs = getAttrs();
1101   SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
1102   for (auto &elt : newAttrs)
1103     if (elt.first == name) {
1104       elt.second = value;
1105       attrs = DictionaryAttr::get(newAttrs, value.getContext());
1106       return;
1107     }
1108 
1109   // Otherwise, add it.
1110   newAttrs.push_back({name, value});
1111   attrs = DictionaryAttr::get(newAttrs, value.getContext());
1112 }
1113 
1114 /// Remove the attribute with the specified name if it exists.  The return
1115 /// value indicates whether the attribute was present or not.
1116 auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
1117   auto origAttrs = getAttrs();
1118   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1119     if (origAttrs[i].first == name) {
1120       // Handle the simple case of removing the only attribute in the list.
1121       if (e == 1) {
1122         attrs = nullptr;
1123         return RemoveResult::Removed;
1124       }
1125 
1126       SmallVector<NamedAttribute, 8> newAttrs;
1127       newAttrs.reserve(origAttrs.size() - 1);
1128       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1129       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1130       attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
1131       return RemoveResult::Removed;
1132     }
1133   }
1134   return RemoveResult::NotFound;
1135 }
1136