//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AttributeStorage
26 //===----------------------------------------------------------------------===//
27 
28 AttributeStorage::AttributeStorage(Type type)
29     : type(type.getAsOpaquePointer()) {}
30 AttributeStorage::AttributeStorage() : type(nullptr) {}
31 
32 Type AttributeStorage::getType() const {
33   return Type::getFromOpaquePointer(type);
34 }
35 void AttributeStorage::setType(Type newType) {
36   type = newType.getAsOpaquePointer();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Attribute
41 //===----------------------------------------------------------------------===//
42 
43 /// Return the type of this attribute.
44 Type Attribute::getType() const { return impl->getType(); }
45 
46 /// Return the context this attribute belongs to.
47 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
48 
49 /// Get the dialect this attribute is registered to.
50 Dialect &Attribute::getDialect() const {
51   return impl->getAbstractAttribute().getDialect();
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // AffineMapAttr
56 //===----------------------------------------------------------------------===//
57 
58 AffineMapAttr AffineMapAttr::get(AffineMap value) {
59   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
60 }
61 
62 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
63 
64 //===----------------------------------------------------------------------===//
65 // ArrayAttr
66 //===----------------------------------------------------------------------===//
67 
68 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
69   return Base::get(context, StandardAttributes::Array, value);
70 }
71 
72 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
73 
74 Attribute ArrayAttr::operator[](unsigned idx) const {
75   assert(idx < size() && "index out of bounds");
76   return getValue()[idx];
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // DictionaryAttr
81 //===----------------------------------------------------------------------===//
82 
83 /// Helper function that does either an in place sort or sorts from source array
84 /// into destination. If inPlace then storage is both the source and the
85 /// destination, else value is the source and storage destination. Returns
86 /// whether source was sorted.
87 template <bool inPlace>
88 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
89                                SmallVectorImpl<NamedAttribute> &storage) {
90   // Specialize for the common case.
91   switch (value.size()) {
92   case 0:
93     // Zero already sorted.
94     break;
95   case 1:
96     // One already sorted but may need to be copied.
97     if (!inPlace)
98       storage.assign({value[0]});
99     break;
100   case 2: {
101     assert(value[0].first != value[1].first &&
102            "DictionaryAttr element names must be unique");
103     bool isSorted = value[0] < value[1];
104     if (inPlace) {
105       if (!isSorted)
106         std::swap(storage[0], storage[1]);
107     } else if (isSorted) {
108       storage.assign({value[0], value[1]});
109     } else {
110       storage.assign({value[1], value[0]});
111     }
112     return !isSorted;
113   }
114   default:
115     if (!inPlace)
116       storage.assign(value.begin(), value.end());
117     // Check to see they are sorted already.
118     bool isSorted = llvm::is_sorted(value);
119     if (!isSorted) {
120       // If not, do a general sort.
121       llvm::array_pod_sort(storage.begin(), storage.end());
122       value = storage;
123     }
124 
125     // Ensure that the attribute elements are unique.
126     assert(std::adjacent_find(value.begin(), value.end(),
127                               [](NamedAttribute l, NamedAttribute r) {
128                                 return l.first == r.first;
129                               }) == value.end() &&
130            "DictionaryAttr element names must be unique");
131     return !isSorted;
132   }
133   return false;
134 }
135 
136 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
137                           SmallVectorImpl<NamedAttribute> &storage) {
138   return dictionaryAttrSort</*inPlace=*/false>(value, storage);
139 }
140 
141 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
142   return dictionaryAttrSort</*inPlace=*/true>(array, array);
143 }
144 
145 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
146                                    MLIRContext *context) {
147   if (value.empty())
148     return DictionaryAttr::getEmpty(context);
149   assert(llvm::all_of(value,
150                       [](const NamedAttribute &attr) { return attr.second; }) &&
151          "value cannot have null entries");
152 
153   // We need to sort the element list to canonicalize it.
154   SmallVector<NamedAttribute, 8> storage;
155   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
156     value = storage;
157 
158   return Base::get(context, StandardAttributes::Dictionary, value);
159 }
160 /// Construct a dictionary with an array of values that is known to already be
161 /// sorted by name and uniqued.
162 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
163                                              MLIRContext *context) {
164   if (value.empty())
165     return DictionaryAttr::getEmpty(context);
166   // Ensure that the attribute elements are unique and sorted.
167   assert(llvm::is_sorted(value,
168                          [](NamedAttribute l, NamedAttribute r) {
169                            return l.first.strref() < r.first.strref();
170                          }) &&
171          "expected attribute values to be sorted");
172   assert(std::adjacent_find(value.begin(), value.end(),
173                             [](NamedAttribute l, NamedAttribute r) {
174                               return l.first == r.first;
175                             }) == value.end() &&
176          "DictionaryAttr element names must be unique");
177   return Base::get(context, StandardAttributes::Dictionary, value);
178 }
179 
180 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
181   return getImpl()->getElements();
182 }
183 
184 /// Return the specified attribute if present, null otherwise.
185 Attribute DictionaryAttr::get(StringRef name) const {
186   Optional<NamedAttribute> attr = getNamed(name);
187   return attr ? attr->second : nullptr;
188 }
189 Attribute DictionaryAttr::get(Identifier name) const {
190   Optional<NamedAttribute> attr = getNamed(name);
191   return attr ? attr->second : nullptr;
192 }
193 
194 /// Return the specified named attribute if present, None otherwise.
195 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
196   ArrayRef<NamedAttribute> values = getValue();
197   const auto *it = llvm::lower_bound(values, name);
198   return it != values.end() && it->first == name ? *it
199                                                  : Optional<NamedAttribute>();
200 }
201 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
202   for (auto elt : getValue())
203     if (elt.first == name)
204       return elt;
205   return llvm::None;
206 }
207 
208 DictionaryAttr::iterator DictionaryAttr::begin() const {
209   return getValue().begin();
210 }
211 DictionaryAttr::iterator DictionaryAttr::end() const {
212   return getValue().end();
213 }
214 size_t DictionaryAttr::size() const { return getValue().size(); }
215 
216 //===----------------------------------------------------------------------===//
217 // FloatAttr
218 //===----------------------------------------------------------------------===//
219 
220 FloatAttr FloatAttr::get(Type type, double value) {
221   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
222 }
223 
224 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
225   return Base::getChecked(loc, StandardAttributes::Float, type, value);
226 }
227 
228 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
229   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
230 }
231 
232 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
233   return Base::getChecked(loc, StandardAttributes::Float, type, value);
234 }
235 
236 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
237 
238 double FloatAttr::getValueAsDouble() const {
239   return getValueAsDouble(getValue());
240 }
241 double FloatAttr::getValueAsDouble(APFloat value) {
242   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
243     bool losesInfo = false;
244     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
245                   &losesInfo);
246   }
247   return value.convertToDouble();
248 }
249 
250 /// Verify construction invariants.
251 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
252   if (!type.isa<FloatType>())
253     return emitError(loc, "expected floating point type");
254   return success();
255 }
256 
257 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
258                                                       double value) {
259   return verifyFloatTypeInvariants(loc, type);
260 }
261 
262 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
263                                                       const APFloat &value) {
264   // Verify that the type is correct.
265   if (failed(verifyFloatTypeInvariants(loc, type)))
266     return failure();
267 
268   // Verify that the type semantics match that of the value.
269   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
270     return emitError(
271         loc, "FloatAttr type doesn't match the type implied by its value");
272   }
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // SymbolRefAttr
278 //===----------------------------------------------------------------------===//
279 
280 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
281   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
282       .cast<FlatSymbolRefAttr>();
283 }
284 
285 SymbolRefAttr SymbolRefAttr::get(StringRef value,
286                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
287                                  MLIRContext *ctx) {
288   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
289 }
290 
291 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
292 
293 StringRef SymbolRefAttr::getLeafReference() const {
294   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
295   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
296 }
297 
298 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
299   return getImpl()->getNestedRefs();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // IntegerAttr
304 //===----------------------------------------------------------------------===//
305 
306 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
307   if (type.isSignlessInteger(1))
308     return BoolAttr::get(value.getBoolValue(), type.getContext());
309   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
310 }
311 
312 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
313   // This uses 64 bit APInts by default for index type.
314   if (type.isIndex())
315     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
316 
317   auto intType = type.cast<IntegerType>();
318   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
319 }
320 
321 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
322 
323 int64_t IntegerAttr::getInt() const {
324   assert((getImpl()->getType().isIndex() ||
325           getImpl()->getType().isSignlessInteger()) &&
326          "must be signless integer");
327   return getValue().getSExtValue();
328 }
329 
330 int64_t IntegerAttr::getSInt() const {
331   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
332   return getValue().getSExtValue();
333 }
334 
335 uint64_t IntegerAttr::getUInt() const {
336   assert(getImpl()->getType().isUnsignedInteger() &&
337          "must be unsigned integer");
338   return getValue().getZExtValue();
339 }
340 
341 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
342   if (type.isa<IntegerType, IndexType>())
343     return success();
344   return emitError(loc, "expected integer or index type");
345 }
346 
347 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
348                                                         int64_t value) {
349   return verifyIntegerTypeInvariants(loc, type);
350 }
351 
352 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
353                                                         const APInt &value) {
354   if (failed(verifyIntegerTypeInvariants(loc, type)))
355     return failure();
356   if (auto integerType = type.dyn_cast<IntegerType>())
357     if (integerType.getWidth() != value.getBitWidth())
358       return emitError(loc, "integer type bit width (")
359              << integerType.getWidth() << ") doesn't match value bit width ("
360              << value.getBitWidth() << ")";
361   return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // BoolAttr
366 
367 bool BoolAttr::getValue() const {
368   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
369   return storage->getValue().getBoolValue();
370 }
371 
372 bool BoolAttr::classof(Attribute attr) {
373   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
374   return intAttr && intAttr.getType().isSignlessInteger(1);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // IntegerSetAttr
379 //===----------------------------------------------------------------------===//
380 
381 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
382   return Base::get(value.getConstraint(0).getContext(),
383                    StandardAttributes::IntegerSet, value);
384 }
385 
386 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
387 
388 //===----------------------------------------------------------------------===//
389 // OpaqueAttr
390 //===----------------------------------------------------------------------===//
391 
392 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
393                            MLIRContext *context) {
394   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
395                    type);
396 }
397 
398 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
399                                   Type type, Location location) {
400   return Base::getChecked(location, StandardAttributes::Opaque, dialect,
401                           attrData, type);
402 }
403 
404 /// Returns the dialect namespace of the opaque attribute.
405 Identifier OpaqueAttr::getDialectNamespace() const {
406   return getImpl()->dialectNamespace;
407 }
408 
409 /// Returns the raw attribute data of the opaque attribute.
410 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
411 
412 /// Verify the construction of an opaque attribute.
413 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
414                                                        Identifier dialect,
415                                                        StringRef attrData,
416                                                        Type type) {
417   if (!Dialect::isValidNamespace(dialect.strref()))
418     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
419   return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // StringAttr
424 //===----------------------------------------------------------------------===//
425 
426 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
427   return get(bytes, NoneType::get(context));
428 }
429 
430 /// Get an instance of a StringAttr with the given string and Type.
431 StringAttr StringAttr::get(StringRef bytes, Type type) {
432   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
433 }
434 
435 StringRef StringAttr::getValue() const { return getImpl()->value; }
436 
437 //===----------------------------------------------------------------------===//
438 // TypeAttr
439 //===----------------------------------------------------------------------===//
440 
441 TypeAttr TypeAttr::get(Type value) {
442   return Base::get(value.getContext(), StandardAttributes::Type, value);
443 }
444 
445 Type TypeAttr::getValue() const { return getImpl()->value; }
446 
447 //===----------------------------------------------------------------------===//
448 // ElementsAttr
449 //===----------------------------------------------------------------------===//
450 
451 ShapedType ElementsAttr::getType() const {
452   return Attribute::getType().cast<ShapedType>();
453 }
454 
455 /// Returns the number of elements held by this attribute.
456 int64_t ElementsAttr::getNumElements() const {
457   return getType().getNumElements();
458 }
459 
460 /// Return the value at the given index. If index does not refer to a valid
461 /// element, then a null attribute is returned.
462 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
463   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
464     return denseAttr.getValue(index);
465   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
466     return opaqueAttr.getValue(index);
467   return cast<SparseElementsAttr>().getValue(index);
468 }
469 
470 /// Return if the given 'index' refers to a valid element in this attribute.
471 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
472   auto type = getType();
473 
474   // Verify that the rank of the indices matches the held type.
475   auto rank = type.getRank();
476   if (rank != static_cast<int64_t>(index.size()))
477     return false;
478 
479   // Verify that all of the indices are within the shape dimensions.
480   auto shape = type.getShape();
481   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
482     return static_cast<int64_t>(index[i]) < shape[i];
483   });
484 }
485 
486 ElementsAttr
487 ElementsAttr::mapValues(Type newElementType,
488                         function_ref<APInt(const APInt &)> mapping) const {
489   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
490     return intOrFpAttr.mapValues(newElementType, mapping);
491   llvm_unreachable("unsupported ElementsAttr subtype");
492 }
493 
494 ElementsAttr
495 ElementsAttr::mapValues(Type newElementType,
496                         function_ref<APInt(const APFloat &)> mapping) const {
497   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
498     return intOrFpAttr.mapValues(newElementType, mapping);
499   llvm_unreachable("unsupported ElementsAttr subtype");
500 }
501 
502 /// Method for support type inquiry through isa, cast and dyn_cast.
503 bool ElementsAttr::classof(Attribute attr) {
504   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
505                   OpaqueElementsAttr, SparseElementsAttr>();
506 }
507 
508 /// Returns the 1 dimensional flattened row-major index from the given
509 /// multi-dimensional index.
510 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
511   assert(isValidIndex(index) && "expected valid multi-dimensional index");
512   auto type = getType();
513 
514   // Reduce the provided multidimensional index into a flattended 1D row-major
515   // index.
516   auto rank = type.getRank();
517   auto shape = type.getShape();
518   uint64_t valueIndex = 0;
519   uint64_t dimMultiplier = 1;
520   for (int i = rank - 1; i >= 0; --i) {
521     valueIndex += index[i] * dimMultiplier;
522     dimMultiplier *= shape[i];
523   }
524   return valueIndex;
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // DenseElementAttr Utilities
529 //===----------------------------------------------------------------------===//
530 
531 /// Get the bitwidth of a dense element type within the buffer.
532 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
533 static size_t getDenseElementStorageWidth(size_t origWidth) {
534   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
535 }
536 static size_t getDenseElementStorageWidth(Type elementType) {
537   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
538 }
539 
/// Set the bit at position `bitPos` in `rawData` to `value`. Bit `bitPos` is
/// bit (bitPos % CHAR_BIT) of byte (bitPos / CHAR_BIT), counted from the least
/// significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}

/// Return the value of the bit at position `bitPos` in `rawData`, using the
/// same bit numbering as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
552 
/// Get a pointer to the start of the significant bytes of `value`'s raw data
/// for an element of `bitWidth` bits.
/// On little-endian hosts this is simply the start of the raw data. On
/// big-endian hosts the pointer is advanced by `8 - ceil(bitWidth/CHAR_BIT)`
/// bytes, i.e. to the last ceil(bitWidth/CHAR_BIT) bytes of the first 8-byte
/// word, where the low-order bytes of a big-endian 64-bit word live.
/// NOTE(review): for `bitWidth` > 64 that offset becomes negative and the
/// pointer lands before the buffer, so multi-word APInts appear unsupported on
/// big-endian hosts — confirm before relying on this for wide integers.
static char *getAPIntDataPos(APInt &value, size_t bitWidth) {
  char *dataPos =
      const_cast<char *>(reinterpret_cast<const char *>(value.getRawData()));
  if (llvm::support::endian::system_endianness() ==
      llvm::support::endianness::big)
    dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT);
  return dataPos;
}

/// Copy the ceil(bitWidth/CHAR_BIT) significant bytes of APInt `value` into
/// `outData`. Note the direction: this reads *from* `value` and writes to
/// `outData` (`value` is non-const only because getAPIntDataPos takes it so).
static void readAPInt(APInt &value, size_t bitWidth, char *outData) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData);
}

/// Copy ceil(bitWidth/CHAR_BIT) bytes from `inData` into the significant bytes
/// of APInt `value`, which must already have the desired bit width.
static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) {
  char *dataPos = getAPIntDataPos(value, bitWidth);
  std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos);
}
575 
576 /// Writes value to the bit position `bitPos` in array `rawData`.
577 static void writeBits(char *rawData, size_t bitPos, APInt value) {
578   size_t bitWidth = value.getBitWidth();
579 
580   // If the bitwidth is 1 we just toggle the specific bit.
581   if (bitWidth == 1)
582     return setBit(rawData, bitPos, value.isOneValue());
583 
584   // Otherwise, the bit position is guaranteed to be byte aligned.
585   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
586   readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT));
587 }
588 
589 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
590 /// `rawData`.
591 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
592   // Handle a boolean bit position.
593   if (bitWidth == 1)
594     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
595 
596   // Otherwise, the bit position must be 8-bit aligned.
597   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
598   APInt result(bitWidth, 0);
599   writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result);
600   return result;
601 }
602 
603 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
604 /// same element count as 'type'.
605 template <typename Values>
606 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
607   return (values.size() == 1) ||
608          (type.getNumElements() == static_cast<int64_t>(values.size()));
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // DenseElementAttr Iterators
613 //===----------------------------------------------------------------------===//
614 
615 //===----------------------------------------------------------------------===//
616 // AttributeElementIterator
617 
618 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
619     DenseElementsAttr attr, size_t index)
620     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
621                                       Attribute, Attribute, Attribute>(
622           attr.getAsOpaquePointer(), index) {}
623 
624 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
625   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
626   Type eltTy = owner.getType().getElementType();
627   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
628     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
629   if (eltTy.isa<IndexType>())
630     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
631   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
632     IntElementIterator intIt(owner, index);
633     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
634     return FloatAttr::get(eltTy, *floatIt);
635   }
636   if (owner.isa<DenseStringElementsAttr>()) {
637     ArrayRef<StringRef> vals = owner.getRawStringData();
638     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
639   }
640   llvm_unreachable("unexpected element type");
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // BoolElementIterator
645 
646 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
647     DenseElementsAttr attr, size_t dataIndex)
648     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
649           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
650 
651 bool DenseElementsAttr::BoolElementIterator::operator*() const {
652   return getBit(getData(), getDataIndex());
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // IntElementIterator
657 
658 DenseElementsAttr::IntElementIterator::IntElementIterator(
659     DenseElementsAttr attr, size_t dataIndex)
660     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
661           attr.getRawData().data(), attr.isSplat(), dataIndex),
662       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
663 
664 APInt DenseElementsAttr::IntElementIterator::operator*() const {
665   return readBits(getData(),
666                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
667                   bitWidth);
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // ComplexIntElementIterator
672 
673 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
674     DenseElementsAttr attr, size_t dataIndex)
675     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
676                                       std::complex<APInt>, std::complex<APInt>,
677                                       std::complex<APInt>>(
678           attr.getRawData().data(), attr.isSplat(), dataIndex) {
679   auto complexType = attr.getType().getElementType().cast<ComplexType>();
680   bitWidth = getDenseElementBitWidth(complexType.getElementType());
681 }
682 
683 std::complex<APInt>
684 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
685   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
686   size_t offset = getDataIndex() * storageWidth * 2;
687   return {readBits(getData(), offset, bitWidth),
688           readBits(getData(), offset + storageWidth, bitWidth)};
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // FloatElementIterator
693 
694 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
695     const llvm::fltSemantics &smt, IntElementIterator it)
696     : llvm::mapped_iterator<IntElementIterator,
697                             std::function<APFloat(const APInt &)>>(
698           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
699 
700 //===----------------------------------------------------------------------===//
701 // ComplexFloatElementIterator
702 
703 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
704     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
705     : llvm::mapped_iterator<
706           ComplexIntElementIterator,
707           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
708           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
709             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
710           }) {}
711 
712 //===----------------------------------------------------------------------===//
713 // DenseElementsAttr
714 //===----------------------------------------------------------------------===//
715 
716 /// Method for support type inquiry through isa, cast and dyn_cast.
717 bool DenseElementsAttr::classof(Attribute attr) {
718   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
719 }
720 
721 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
722                                          ArrayRef<Attribute> values) {
723   assert(hasSameElementsOrSplat(type, values));
724 
725   // If the element type is not based on int/float/index, assume it is a string
726   // type.
727   auto eltType = type.getElementType();
728   if (!type.getElementType().isIntOrIndexOrFloat()) {
729     SmallVector<StringRef, 8> stringValues;
730     stringValues.reserve(values.size());
731     for (Attribute attr : values) {
732       assert(attr.isa<StringAttr>() &&
733              "expected string value for non integer/index/float element");
734       stringValues.push_back(attr.cast<StringAttr>().getValue());
735     }
736     return get(type, stringValues);
737   }
738 
739   // Otherwise, get the raw storage width to use for the allocation.
740   size_t bitWidth = getDenseElementBitWidth(eltType);
741   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
742 
743   // Compress the attribute values into a character buffer.
744   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
745                             values.size());
746   APInt intVal;
747   for (unsigned i = 0, e = values.size(); i < e; ++i) {
748     assert(eltType == values[i].getType() &&
749            "expected attribute value to have element type");
750 
751     switch (eltType.getKind()) {
752     case StandardTypes::BF16:
753     case StandardTypes::F16:
754     case StandardTypes::F32:
755     case StandardTypes::F64:
756       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
757       break;
758     case StandardTypes::Integer:
759     case StandardTypes::Index:
760       intVal = values[i].cast<IntegerAttr>().getValue();
761       break;
762     default:
763       llvm_unreachable("unexpected element type");
764     }
765     assert(intVal.getBitWidth() == bitWidth &&
766            "expected value to have same bitwidth as element type");
767     writeBits(data.data(), i * storageBitWidth, intVal);
768   }
769   return DenseIntOrFPElementsAttr::getRaw(type, data,
770                                           /*isSplat=*/(values.size() == 1));
771 }
772 
773 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
774                                          ArrayRef<bool> values) {
775   assert(hasSameElementsOrSplat(type, values));
776   assert(type.getElementType().isInteger(1));
777 
778   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
779   for (int i = 0, e = values.size(); i != e; ++i)
780     setBit(buff.data(), i, values[i]);
781   return DenseIntOrFPElementsAttr::getRaw(type, buff,
782                                           /*isSplat=*/(values.size() == 1));
783 }
784 
785 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
786                                          ArrayRef<StringRef> values) {
787   assert(!type.getElementType().isIntOrFloat());
788   return DenseStringElementsAttr::get(type, values);
789 }
790 
791 /// Constructs a dense integer elements attribute from an array of APInt
792 /// values. Each APInt value is expected to have the same bitwidth as the
793 /// element type of 'type'.
794 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
795                                          ArrayRef<APInt> values) {
796   assert(type.getElementType().isIntOrIndex());
797   assert(hasSameElementsOrSplat(type, values));
798   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
799   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
800                                           /*isSplat=*/(values.size() == 1));
801 }
802 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
803                                          ArrayRef<std::complex<APInt>> values) {
804   ComplexType complex = type.getElementType().cast<ComplexType>();
805   assert(complex.getElementType().isa<IntegerType>());
806   assert(hasSameElementsOrSplat(type, values));
807   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
808   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
809                           values.size() * 2);
810   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
811                                           /*isSplat=*/(values.size() == 1));
812 }
813 
814 // Constructs a dense float elements attribute from an array of APFloat
815 // values. Each APFloat value is expected to have the same bitwidth as the
816 // element type of 'type'.
817 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
818                                          ArrayRef<APFloat> values) {
819   assert(type.getElementType().isa<FloatType>());
820   assert(hasSameElementsOrSplat(type, values));
821   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
822   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
823                                           /*isSplat=*/(values.size() == 1));
824 }
825 DenseElementsAttr
826 DenseElementsAttr::get(ShapedType type,
827                        ArrayRef<std::complex<APFloat>> values) {
828   ComplexType complex = type.getElementType().cast<ComplexType>();
829   assert(complex.getElementType().isa<FloatType>());
830   assert(hasSameElementsOrSplat(type, values));
831   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
832                            values.size() * 2);
833   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
834   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
835                                           /*isSplat=*/(values.size() == 1));
836 }
837 
838 /// Construct a dense elements attribute from a raw buffer representing the
839 /// data for this attribute. Users should generally not use this methods as
840 /// the expected buffer format may not be a form the user expects.
841 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
842                                                       ArrayRef<char> rawBuffer,
843                                                       bool isSplatBuffer) {
844   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
845 }
846 
847 /// Returns true if the given buffer is a valid raw buffer for the given type.
848 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
849                                          ArrayRef<char> rawBuffer,
850                                          bool &detectedSplat) {
851   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
852   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
853 
854   // Storage width of 1 is special as it is packed by the bit.
855   if (storageWidth == 1) {
856     // Check for a splat, or a buffer equal to the number of elements.
857     if ((detectedSplat = rawBuffer.size() == 1))
858       return true;
859     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
860   }
861   // All other types are 8-bit aligned.
862   if ((detectedSplat = rawBufferWidth == storageWidth))
863     return true;
864   return rawBufferWidth == (storageWidth * type.getNumElements());
865 }
866 
867 /// Check the information for a C++ data type, check if this type is valid for
868 /// the current attribute. This method is used to verify specific type
869 /// invariants that the templatized 'getValues' method cannot.
870 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
871                               bool isSigned) {
872   // Make sure that the data element size is the same as the type element width.
873   if (getDenseElementBitWidth(type) !=
874       static_cast<size_t>(dataEltSize * CHAR_BIT))
875     return false;
876 
877   // Check that the element type is either float or integer or index.
878   if (!isInt)
879     return type.isa<FloatType>();
880   if (type.isIndex())
881     return true;
882 
883   auto intType = type.dyn_cast<IntegerType>();
884   if (!intType)
885     return false;
886 
887   // Make sure signedness semantics is consistent.
888   if (intType.isSignless())
889     return true;
890   return intType.isSigned() ? isSigned : !isSigned;
891 }
892 
893 /// Defaults down the subclass implementation.
894 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
895                                                    ArrayRef<char> data,
896                                                    int64_t dataEltSize,
897                                                    bool isInt, bool isSigned) {
898   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
899                                                  isSigned);
900 }
901 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
902                                                       ArrayRef<char> data,
903                                                       int64_t dataEltSize,
904                                                       bool isInt,
905                                                       bool isSigned) {
906   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
907                                                     isInt, isSigned);
908 }
909 
910 /// A method used to verify specific type invariants that the templatized 'get'
911 /// method cannot.
912 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
913                                           bool isSigned) const {
914   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
915                              isSigned);
916 }
917 
918 /// Check the information for a C++ data type, check if this type is valid for
919 /// the current attribute.
920 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
921                                        bool isSigned) const {
922   return ::isValidIntOrFloat(
923       getType().getElementType().cast<ComplexType>().getElementType(),
924       dataEltSize / 2, isInt, isSigned);
925 }
926 
927 /// Returns if this attribute corresponds to a splat, i.e. if all element
928 /// values are the same.
929 bool DenseElementsAttr::isSplat() const {
930   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
931 }
932 
933 /// Return the held element values as a range of Attributes.
934 auto DenseElementsAttr::getAttributeValues() const
935     -> llvm::iterator_range<AttributeElementIterator> {
936   return {attr_value_begin(), attr_value_end()};
937 }
938 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
939   return AttributeElementIterator(*this, 0);
940 }
941 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
942   return AttributeElementIterator(*this, getNumElements());
943 }
944 
945 /// Return the held element values as a range of bool. The element type of
946 /// this attribute must be of integer type of bitwidth 1.
947 auto DenseElementsAttr::getBoolValues() const
948     -> llvm::iterator_range<BoolElementIterator> {
949   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
950   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
951   (void)eltType;
952   return {BoolElementIterator(*this, 0),
953           BoolElementIterator(*this, getNumElements())};
954 }
955 
956 /// Return the held element values as a range of APInts. The element type of
957 /// this attribute must be of integer type.
958 auto DenseElementsAttr::getIntValues() const
959     -> llvm::iterator_range<IntElementIterator> {
960   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
961   return {raw_int_begin(), raw_int_end()};
962 }
963 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
964   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
965   return raw_int_begin();
966 }
967 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
968   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
969   return raw_int_end();
970 }
971 auto DenseElementsAttr::getComplexIntValues() const
972     -> llvm::iterator_range<ComplexIntElementIterator> {
973   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
974   (void)eltTy;
975   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
976   return {ComplexIntElementIterator(*this, 0),
977           ComplexIntElementIterator(*this, getNumElements())};
978 }
979 
980 /// Return the held element values as a range of APFloat. The element type of
981 /// this attribute must be of float type.
982 auto DenseElementsAttr::getFloatValues() const
983     -> llvm::iterator_range<FloatElementIterator> {
984   auto elementType = getType().getElementType().cast<FloatType>();
985   const auto &elementSemantics = elementType.getFloatSemantics();
986   return {FloatElementIterator(elementSemantics, raw_int_begin()),
987           FloatElementIterator(elementSemantics, raw_int_end())};
988 }
989 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
990   return getFloatValues().begin();
991 }
992 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
993   return getFloatValues().end();
994 }
995 auto DenseElementsAttr::getComplexFloatValues() const
996     -> llvm::iterator_range<ComplexFloatElementIterator> {
997   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
998   assert(eltTy.isa<FloatType>() && "expected complex float type");
999   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1000   return {{semantics, {*this, 0}},
1001           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1002 }
1003 
1004 /// Return the raw storage data held by this attribute.
1005 ArrayRef<char> DenseElementsAttr::getRawData() const {
1006   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1007 }
1008 
1009 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1010   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1011 }
1012 
1013 /// Return a new DenseElementsAttr that has the same data as the current
1014 /// attribute, but has been reshaped to 'newType'. The new type must have the
1015 /// same total number of elements as well as element type.
1016 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1017   ShapedType curType = getType();
1018   if (curType == newType)
1019     return *this;
1020 
1021   (void)curType;
1022   assert(newType.getElementType() == curType.getElementType() &&
1023          "expected the same element type");
1024   assert(newType.getNumElements() == curType.getNumElements() &&
1025          "expected the same number of elements");
1026   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1027 }
1028 
1029 DenseElementsAttr
1030 DenseElementsAttr::mapValues(Type newElementType,
1031                              function_ref<APInt(const APInt &)> mapping) const {
1032   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1033 }
1034 
1035 DenseElementsAttr DenseElementsAttr::mapValues(
1036     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1037   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // DenseStringElementsAttr
1042 //===----------------------------------------------------------------------===//
1043 
1044 DenseStringElementsAttr
1045 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1046   return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
1047                    type, values, (values.size() == 1));
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // DenseIntOrFPElementsAttr
1052 //===----------------------------------------------------------------------===//
1053 
1054 /// Utility method to write a range of APInt values to a buffer.
1055 template <typename APRangeT>
1056 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1057                                 APRangeT &&values) {
1058   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1059   size_t offset = 0;
1060   for (auto it = values.begin(), e = values.end(); it != e;
1061        ++it, offset += storageWidth) {
1062     assert((*it).getBitWidth() <= storageWidth);
1063     writeBits(data.data(), offset, *it);
1064   }
1065 }
1066 
1067 /// Constructs a dense elements attribute from an array of raw APFloat values.
1068 /// Each APFloat value is expected to have the same bitwidth as the element
1069 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1070 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1071                                                    size_t storageWidth,
1072                                                    ArrayRef<APFloat> values,
1073                                                    bool isSplat) {
1074   std::vector<char> data;
1075   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1076   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1077   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1078 }
1079 
1080 /// Constructs a dense elements attribute from an array of raw APInt values.
1081 /// Each APInt value is expected to have the same bitwidth as the element type
1082 /// of 'type'.
1083 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1084                                                    size_t storageWidth,
1085                                                    ArrayRef<APInt> values,
1086                                                    bool isSplat) {
1087   std::vector<char> data;
1088   writeAPIntsToBuffer(storageWidth, data, values);
1089   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1090 }
1091 
1092 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1093                                                    ArrayRef<char> data,
1094                                                    bool isSplat) {
1095   assert((type.isa<RankedTensorType, VectorType>()) &&
1096          "type must be ranked tensor or vector");
1097   assert(type.hasStaticShape() && "type must have static shape");
1098   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
1099                    type, data, isSplat);
1100 }
1101 
1102 /// Overload of the raw 'get' method that asserts that the given type is of
1103 /// complex type. This method is used to verify type invariants that the
1104 /// templatized 'get' method cannot.
1105 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1106                                                           ArrayRef<char> data,
1107                                                           int64_t dataEltSize,
1108                                                           bool isInt,
1109                                                           bool isSigned) {
1110   assert(::isValidIntOrFloat(
1111       type.getElementType().cast<ComplexType>().getElementType(),
1112       dataEltSize / 2, isInt, isSigned));
1113 
1114   int64_t numElements = data.size() / dataEltSize;
1115   assert(numElements == 1 || numElements == type.getNumElements());
1116   return getRaw(type, data, /*isSplat=*/numElements == 1);
1117 }
1118 
1119 /// Overload of the 'getRaw' method that asserts that the given type is of
1120 /// integer type. This method is used to verify type invariants that the
1121 /// templatized 'get' method cannot.
1122 DenseElementsAttr
1123 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1124                                            int64_t dataEltSize, bool isInt,
1125                                            bool isSigned) {
1126   assert(
1127       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1128 
1129   int64_t numElements = data.size() / dataEltSize;
1130   assert(numElements == 1 || numElements == type.getNumElements());
1131   return getRaw(type, data, /*isSplat=*/numElements == 1);
1132 }
1133 
1134 //===----------------------------------------------------------------------===//
1135 // DenseFPElementsAttr
1136 //===----------------------------------------------------------------------===//
1137 
1138 template <typename Fn, typename Attr>
1139 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1140                                 Type newElementType,
1141                                 llvm::SmallVectorImpl<char> &data) {
1142   size_t bitWidth = getDenseElementBitWidth(newElementType);
1143   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1144 
1145   ShapedType newArrayType;
1146   if (inType.isa<RankedTensorType>())
1147     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1148   else if (inType.isa<UnrankedTensorType>())
1149     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1150   else if (inType.isa<VectorType>())
1151     newArrayType = VectorType::get(inType.getShape(), newElementType);
1152   else
1153     assert(newArrayType && "Unhandled tensor type");
1154 
1155   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1156   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1157 
1158   // Functor used to process a single element value of the attribute.
1159   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1160     auto newInt = mapping(value);
1161     assert(newInt.getBitWidth() == bitWidth);
1162     writeBits(data.data(), index * storageBitWidth, newInt);
1163   };
1164 
1165   // Check for the splat case.
1166   if (attr.isSplat()) {
1167     processElt(*attr.begin(), /*index=*/0);
1168     return newArrayType;
1169   }
1170 
1171   // Otherwise, process all of the element values.
1172   uint64_t elementIdx = 0;
1173   for (auto value : attr)
1174     processElt(value, elementIdx++);
1175   return newArrayType;
1176 }
1177 
1178 DenseElementsAttr DenseFPElementsAttr::mapValues(
1179     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1180   llvm::SmallVector<char, 8> elementData;
1181   auto newArrayType =
1182       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1183 
1184   return getRaw(newArrayType, elementData, isSplat());
1185 }
1186 
1187 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1188 bool DenseFPElementsAttr::classof(Attribute attr) {
1189   return attr.isa<DenseElementsAttr>() &&
1190          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // DenseIntElementsAttr
1195 //===----------------------------------------------------------------------===//
1196 
1197 DenseElementsAttr DenseIntElementsAttr::mapValues(
1198     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1199   llvm::SmallVector<char, 8> elementData;
1200   auto newArrayType =
1201       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1202 
1203   return getRaw(newArrayType, elementData, isSplat());
1204 }
1205 
1206 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1207 bool DenseIntElementsAttr::classof(Attribute attr) {
1208   return attr.isa<DenseElementsAttr>() &&
1209          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // OpaqueElementsAttr
1214 //===----------------------------------------------------------------------===//
1215 
1216 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1217                                            StringRef bytes) {
1218   assert(TensorType::isValidElementType(type.getElementType()) &&
1219          "Input element type should be a valid tensor element type");
1220   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
1221                    dialect, bytes);
1222 }
1223 
1224 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1225 
1226 /// Return the value at the given index. If index does not refer to a valid
1227 /// element, then a null attribute is returned.
1228 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1229   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1230   if (Dialect *dialect = getDialect())
1231     return dialect->extractElementHook(*this, index);
1232   return Attribute();
1233 }
1234 
1235 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1236 
1237 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1238   if (auto *d = getDialect())
1239     return d->decodeHook(*this, result);
1240   return true;
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // SparseElementsAttr
1245 //===----------------------------------------------------------------------===//
1246 
1247 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1248                                            DenseElementsAttr indices,
1249                                            DenseElementsAttr values) {
1250   assert(indices.getType().getElementType().isInteger(64) &&
1251          "expected sparse indices to be 64-bit integer values");
1252   assert((type.isa<RankedTensorType, VectorType>()) &&
1253          "type must be ranked tensor or vector");
1254   assert(type.hasStaticShape() && "type must have static shape");
1255   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
1256                    indices.cast<DenseIntElementsAttr>(), values);
1257 }
1258 
1259 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1260   return getImpl()->indices;
1261 }
1262 
1263 DenseElementsAttr SparseElementsAttr::getValues() const {
1264   return getImpl()->values;
1265 }
1266 
1267 /// Return the value of the element at the given index.
1268 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1269   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1270   auto type = getType();
1271 
1272   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1273   // as a 1-D index array.
1274   auto sparseIndices = getIndices();
1275   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1276 
1277   // Check to see if the indices are a splat.
1278   if (sparseIndices.isSplat()) {
1279     // If the index is also not a splat of the index value, we know that the
1280     // value is zero.
1281     auto splatIndex = *sparseIndexValues.begin();
1282     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1283       return getZeroAttr();
1284 
1285     // If the indices are a splat, we also expect the values to be a splat.
1286     assert(getValues().isSplat() && "expected splat values");
1287     return getValues().getSplatValue();
1288   }
1289 
1290   // Build a mapping between known indices and the offset of the stored element.
1291   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1292   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1293   size_t rank = type.getRank();
1294   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1295     mappedIndices.try_emplace(
1296         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1297 
1298   // Look for the provided index key within the mapped indices. If the provided
1299   // index is not found, then return a zero attribute.
1300   auto it = mappedIndices.find(index);
1301   if (it == mappedIndices.end())
1302     return getZeroAttr();
1303 
1304   // Otherwise, return the held sparse value element.
1305   return getValues().getValue(it->second);
1306 }
1307 
1308 /// Get a zero APFloat for the given sparse attribute.
1309 APFloat SparseElementsAttr::getZeroAPFloat() const {
1310   auto eltType = getType().getElementType().cast<FloatType>();
1311   return APFloat(eltType.getFloatSemantics());
1312 }
1313 
1314 /// Get a zero APInt for the given sparse attribute.
1315 APInt SparseElementsAttr::getZeroAPInt() const {
1316   auto eltType = getType().getElementType().cast<IntegerType>();
1317   return APInt::getNullValue(eltType.getWidth());
1318 }
1319 
1320 /// Get a zero attribute for the given attribute type.
1321 Attribute SparseElementsAttr::getZeroAttr() const {
1322   auto eltType = getType().getElementType();
1323 
1324   // Handle floating point elements.
1325   if (eltType.isa<FloatType>())
1326     return FloatAttr::get(eltType, 0);
1327 
1328   // Otherwise, this is an integer.
1329   // TODO: Handle StringAttr here.
1330   return IntegerAttr::get(eltType, 0);
1331 }
1332 
1333 /// Flatten, and return, all of the sparse indices in this attribute in
1334 /// row-major order.
1335 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1336   std::vector<ptrdiff_t> flatSparseIndices;
1337 
1338   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1339   // as a 1-D index array.
1340   auto sparseIndices = getIndices();
1341   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1342   if (sparseIndices.isSplat()) {
1343     SmallVector<uint64_t, 8> indices(getType().getRank(),
1344                                      *sparseIndexValues.begin());
1345     flatSparseIndices.push_back(getFlattenedIndex(indices));
1346     return flatSparseIndices;
1347   }
1348 
1349   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1350   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1351   size_t rank = getType().getRank();
1352   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1353     flatSparseIndices.push_back(getFlattenedIndex(
1354         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1355   return flatSparseIndices;
1356 }
1357 
1358 //===----------------------------------------------------------------------===//
1359 // MutableDictionaryAttr
1360 //===----------------------------------------------------------------------===//
1361 
1362 MutableDictionaryAttr::MutableDictionaryAttr(
1363     ArrayRef<NamedAttribute> attributes) {
1364   setAttrs(attributes);
1365 }
1366 
1367 /// Return the underlying dictionary attribute.
1368 DictionaryAttr
1369 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1370   // Construct empty DictionaryAttr if needed.
1371   if (!attrs)
1372     return DictionaryAttr::get({}, context);
1373   return attrs;
1374 }
1375 
1376 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1377   return attrs ? attrs.getValue() : llvm::None;
1378 }
1379 
1380 /// Replace the held attributes with ones provided in 'newAttrs'.
1381 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1382   // Don't create an attribute list if there are no attributes.
1383   if (attributes.empty())
1384     attrs = nullptr;
1385   else
1386     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1387 }
1388 
1389 /// Return the specified attribute if present, null otherwise.
1390 Attribute MutableDictionaryAttr::get(StringRef name) const {
1391   return attrs ? attrs.get(name) : nullptr;
1392 }
1393 
1394 /// Return the specified attribute if present, null otherwise.
1395 Attribute MutableDictionaryAttr::get(Identifier name) const {
1396   return attrs ? attrs.get(name) : nullptr;
1397 }
1398 
1399 /// Return the specified named attribute if present, None otherwise.
1400 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1401   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1402 }
1403 Optional<NamedAttribute>
1404 MutableDictionaryAttr::getNamed(Identifier name) const {
1405   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1406 }
1407 
1408 /// If the an attribute exists with the specified name, change it to the new
1409 /// value.  Otherwise, add a new attribute with the specified name/value.
1410 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1411   assert(value && "attributes may never be null");
1412 
1413   // Look for an existing value for the given name, and set it in-place.
1414   ArrayRef<NamedAttribute> values = getAttrs();
1415   const auto *it = llvm::find_if(
1416       values, [name](NamedAttribute attr) { return attr.first == name; });
1417   if (it != values.end()) {
1418     // Bail out early if the value is the same as what we already have.
1419     if (it->second == value)
1420       return;
1421 
1422     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1423     newAttrs[it - values.begin()].second = value;
1424     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1425     return;
1426   }
1427 
1428   // Otherwise, insert the new attribute into its sorted position.
1429   it = llvm::lower_bound(values, name);
1430   SmallVector<NamedAttribute, 8> newAttrs;
1431   newAttrs.reserve(values.size() + 1);
1432   newAttrs.append(values.begin(), it);
1433   newAttrs.push_back({name, value});
1434   newAttrs.append(it, values.end());
1435   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1436 }
1437 
1438 /// Remove the attribute with the specified name if it exists.  The return
1439 /// value indicates whether the attribute was present or not.
1440 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1441   auto origAttrs = getAttrs();
1442   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1443     if (origAttrs[i].first == name) {
1444       // Handle the simple case of removing the only attribute in the list.
1445       if (e == 1) {
1446         attrs = nullptr;
1447         return RemoveResult::Removed;
1448       }
1449 
1450       SmallVector<NamedAttribute, 8> newAttrs;
1451       newAttrs.reserve(origAttrs.size() - 1);
1452       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1453       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1454       attrs = DictionaryAttr::getWithSorted(newAttrs,
1455                                             newAttrs[0].second.getContext());
1456       return RemoveResult::Removed;
1457     }
1458   }
1459   return RemoveResult::NotFound;
1460 }
1461 
1462 bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
1463   return strcmp(lhs.first.data(), rhs.first.data()) < 0;
1464 }
/// Order a named attribute against a raw attribute name (used e.g. by
/// lower_bound over a sorted attribute list).
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
  // This is correct even when lhs.first.data()[rhs.size()] is not a zero
  // string terminator (i.e. lhs's name is longer than rhs), because we only
  // care about a less than comparison: if the first rhs.size() characters
  // match, strncmp returns 0 and lhs is correctly reported as not less than
  // rhs. If lhs's name is shorter, its terminator compares below rhs's next
  // character, so lhs is reported as less. strncmp never reads more than
  // rhs.size() characters from rhs, so rhs need not be NUL-terminated.
  // NOTE(review): strncmp also stops at a NUL embedded in rhs, so such an rhs
  // is effectively truncated at that point.
  // This can't use memcmp, because it doesn't guarantee that it will stop
  // reading both buffers if one is shorter than the other, even if there is
  // a difference.
  return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
}
1473