//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Function.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 // AttributeStorage
25 //===----------------------------------------------------------------------===//
26 
27 AttributeStorage::AttributeStorage(Type type)
28     : type(type.getAsOpaquePointer()) {}
29 AttributeStorage::AttributeStorage() : type(nullptr) {}
30 
31 Type AttributeStorage::getType() const {
32   return Type::getFromOpaquePointer(type);
33 }
34 void AttributeStorage::setType(Type newType) {
35   type = newType.getAsOpaquePointer();
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Attribute
40 //===----------------------------------------------------------------------===//
41 
42 /// Return the type of this attribute.
43 Type Attribute::getType() const { return impl->getType(); }
44 
45 /// Return the context this attribute belongs to.
46 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
47 
48 /// Get the dialect this attribute is registered to.
49 Dialect &Attribute::getDialect() const { return impl->getDialect(); }
50 
51 //===----------------------------------------------------------------------===//
52 // AffineMapAttr
53 //===----------------------------------------------------------------------===//
54 
55 AffineMapAttr AffineMapAttr::get(AffineMap value) {
56   return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
57 }
58 
59 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
60 
61 //===----------------------------------------------------------------------===//
62 // ArrayAttr
63 //===----------------------------------------------------------------------===//
64 
65 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
66   return Base::get(context, StandardAttributes::Array, value);
67 }
68 
69 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
70 
71 //===----------------------------------------------------------------------===//
72 // BoolAttr
73 //===----------------------------------------------------------------------===//
74 
75 bool BoolAttr::getValue() const { return getImpl()->value; }
76 
77 //===----------------------------------------------------------------------===//
78 // DictionaryAttr
79 //===----------------------------------------------------------------------===//
80 
81 /// Perform a three-way comparison between the names of the specified
82 /// NamedAttributes.
83 static int compareNamedAttributes(const NamedAttribute *lhs,
84                                   const NamedAttribute *rhs) {
85   return lhs->first.strref().compare(rhs->first.strref());
86 }
87 
88 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
89                                    MLIRContext *context) {
90   assert(llvm::all_of(value,
91                       [](const NamedAttribute &attr) { return attr.second; }) &&
92          "value cannot have null entries");
93 
94   // We need to sort the element list to canonicalize it, but we also don't want
95   // to do a ton of work in the super common case where the element list is
96   // already sorted.
97   SmallVector<NamedAttribute, 8> storage;
98   switch (value.size()) {
99   case 0:
100     break;
101   case 1:
102     // A single element is already sorted.
103     break;
104   case 2:
105     assert(value[0].first != value[1].first &&
106            "DictionaryAttr element names must be unique");
107 
108     // Don't invoke a general sort for two element case.
109     if (value[0].first.strref() > value[1].first.strref()) {
110       storage.push_back(value[1]);
111       storage.push_back(value[0]);
112       value = storage;
113     }
114     break;
115   default:
116     // Check to see they are sorted already.
117     bool isSorted = true;
118     for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
119       if (value[i].first.strref() > value[i + 1].first.strref()) {
120         isSorted = false;
121         break;
122       }
123     }
124     // If not, do a general sort.
125     if (!isSorted) {
126       storage.append(value.begin(), value.end());
127       llvm::array_pod_sort(storage.begin(), storage.end(),
128                            compareNamedAttributes);
129       value = storage;
130     }
131 
132     // Ensure that the attribute elements are unique.
133     assert(std::adjacent_find(value.begin(), value.end(),
134                               [](NamedAttribute l, NamedAttribute r) {
135                                 return l.first == r.first;
136                               }) == value.end() &&
137            "DictionaryAttr element names must be unique");
138   }
139 
140   return Base::get(context, StandardAttributes::Dictionary, value);
141 }
142 
143 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
144   return getImpl()->getElements();
145 }
146 
147 /// Return the specified attribute if present, null otherwise.
148 Attribute DictionaryAttr::get(StringRef name) const {
149   ArrayRef<NamedAttribute> values = getValue();
150   auto compare = [](NamedAttribute attr, StringRef name) {
151     return attr.first.strref() < name;
152   };
153   auto it = llvm::lower_bound(values, name, compare);
154   return it != values.end() && it->first.is(name) ? it->second : Attribute();
155 }
156 Attribute DictionaryAttr::get(Identifier name) const {
157   for (auto elt : getValue())
158     if (elt.first == name)
159       return elt.second;
160   return nullptr;
161 }
162 
163 DictionaryAttr::iterator DictionaryAttr::begin() const {
164   return getValue().begin();
165 }
166 DictionaryAttr::iterator DictionaryAttr::end() const {
167   return getValue().end();
168 }
169 size_t DictionaryAttr::size() const { return getValue().size(); }
170 
171 //===----------------------------------------------------------------------===//
172 // FloatAttr
173 //===----------------------------------------------------------------------===//
174 
175 FloatAttr FloatAttr::get(Type type, double value) {
176   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
177 }
178 
179 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
180   return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
181                           type, value);
182 }
183 
184 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
185   return Base::get(type.getContext(), StandardAttributes::Float, type, value);
186 }
187 
188 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
189   return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
190                           type, value);
191 }
192 
193 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
194 
195 double FloatAttr::getValueAsDouble() const {
196   return getValueAsDouble(getValue());
197 }
198 double FloatAttr::getValueAsDouble(APFloat value) {
199   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
200     bool losesInfo = false;
201     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
202                   &losesInfo);
203   }
204   return value.convertToDouble();
205 }
206 
207 /// Verify construction invariants.
208 static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
209                                                Type type) {
210   if (!type.isa<FloatType>())
211     return emitOptionalError(loc, "expected floating point type");
212   return success();
213 }
214 
215 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
216                                                       MLIRContext *ctx,
217                                                       Type type, double value) {
218   return verifyFloatTypeInvariants(loc, type);
219 }
220 
221 LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
222                                                       MLIRContext *ctx,
223                                                       Type type,
224                                                       const APFloat &value) {
225   // Verify that the type is correct.
226   if (failed(verifyFloatTypeInvariants(loc, type)))
227     return failure();
228 
229   // Verify that the type semantics match that of the value.
230   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
231     return emitOptionalError(
232         loc, "FloatAttr type doesn't match the type implied by its value");
233   }
234   return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // SymbolRefAttr
239 //===----------------------------------------------------------------------===//
240 
241 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
242   return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
243       .cast<FlatSymbolRefAttr>();
244 }
245 
246 SymbolRefAttr SymbolRefAttr::get(StringRef value,
247                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
248                                  MLIRContext *ctx) {
249   return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
250 }
251 
252 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
253 
254 StringRef SymbolRefAttr::getLeafReference() const {
255   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
256   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
257 }
258 
259 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
260   return getImpl()->getNestedRefs();
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // IntegerAttr
265 //===----------------------------------------------------------------------===//
266 
267 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
268   return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
269 }
270 
271 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
272   // This uses 64 bit APInts by default for index type.
273   if (type.isIndex())
274     return get(type, APInt(64, value));
275 
276   auto intType = type.cast<IntegerType>();
277   return get(type, APInt(intType.getWidth(), value));
278 }
279 
280 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
281 
282 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
283 
284 //===----------------------------------------------------------------------===//
285 // IntegerSetAttr
286 //===----------------------------------------------------------------------===//
287 
288 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
289   return Base::get(value.getConstraint(0).getContext(),
290                    StandardAttributes::IntegerSet, value);
291 }
292 
293 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
294 
295 //===----------------------------------------------------------------------===//
296 // OpaqueAttr
297 //===----------------------------------------------------------------------===//
298 
299 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
300                            MLIRContext *context) {
301   return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
302                    type);
303 }
304 
305 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
306                                   Type type, Location location) {
307   return Base::getChecked(location, type.getContext(),
308                           StandardAttributes::Opaque, dialect, attrData, type);
309 }
310 
311 /// Returns the dialect namespace of the opaque attribute.
312 Identifier OpaqueAttr::getDialectNamespace() const {
313   return getImpl()->dialectNamespace;
314 }
315 
316 /// Returns the raw attribute data of the opaque attribute.
317 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
318 
319 /// Verify the construction of an opaque attribute.
320 LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
321                                                        MLIRContext *context,
322                                                        Identifier dialect,
323                                                        StringRef attrData,
324                                                        Type type) {
325   if (!Dialect::isValidNamespace(dialect.strref()))
326     return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
327   return success();
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // StringAttr
332 //===----------------------------------------------------------------------===//
333 
334 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
335   return get(bytes, NoneType::get(context));
336 }
337 
338 /// Get an instance of a StringAttr with the given string and Type.
339 StringAttr StringAttr::get(StringRef bytes, Type type) {
340   return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
341 }
342 
343 StringRef StringAttr::getValue() const { return getImpl()->value; }
344 
345 //===----------------------------------------------------------------------===//
346 // TypeAttr
347 //===----------------------------------------------------------------------===//
348 
349 TypeAttr TypeAttr::get(Type value) {
350   return Base::get(value.getContext(), StandardAttributes::Type, value);
351 }
352 
353 Type TypeAttr::getValue() const { return getImpl()->value; }
354 
355 //===----------------------------------------------------------------------===//
356 // ElementsAttr
357 //===----------------------------------------------------------------------===//
358 
359 ShapedType ElementsAttr::getType() const {
360   return Attribute::getType().cast<ShapedType>();
361 }
362 
363 /// Returns the number of elements held by this attribute.
364 int64_t ElementsAttr::getNumElements() const {
365   return getType().getNumElements();
366 }
367 
368 /// Return the value at the given index. If index does not refer to a valid
369 /// element, then a null attribute is returned.
370 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
371   switch (getKind()) {
372   case StandardAttributes::DenseElements:
373     return cast<DenseElementsAttr>().getValue(index);
374   case StandardAttributes::OpaqueElements:
375     return cast<OpaqueElementsAttr>().getValue(index);
376   case StandardAttributes::SparseElements:
377     return cast<SparseElementsAttr>().getValue(index);
378   default:
379     llvm_unreachable("unknown ElementsAttr kind");
380   }
381 }
382 
383 /// Return if the given 'index' refers to a valid element in this attribute.
384 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
385   auto type = getType();
386 
387   // Verify that the rank of the indices matches the held type.
388   auto rank = type.getRank();
389   if (rank != static_cast<int64_t>(index.size()))
390     return false;
391 
392   // Verify that all of the indices are within the shape dimensions.
393   auto shape = type.getShape();
394   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
395     return static_cast<int64_t>(index[i]) < shape[i];
396   });
397 }
398 
399 ElementsAttr
400 ElementsAttr::mapValues(Type newElementType,
401                         function_ref<APInt(const APInt &)> mapping) const {
402   switch (getKind()) {
403   case StandardAttributes::DenseElements:
404     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
405   default:
406     llvm_unreachable("unsupported ElementsAttr subtype");
407   }
408 }
409 
410 ElementsAttr
411 ElementsAttr::mapValues(Type newElementType,
412                         function_ref<APInt(const APFloat &)> mapping) const {
413   switch (getKind()) {
414   case StandardAttributes::DenseElements:
415     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
416   default:
417     llvm_unreachable("unsupported ElementsAttr subtype");
418   }
419 }
420 
421 /// Returns the 1 dimensional flattened row-major index from the given
422 /// multi-dimensional index.
423 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
424   assert(isValidIndex(index) && "expected valid multi-dimensional index");
425   auto type = getType();
426 
427   // Reduce the provided multidimensional index into a flattended 1D row-major
428   // index.
429   auto rank = type.getRank();
430   auto shape = type.getShape();
431   uint64_t valueIndex = 0;
432   uint64_t dimMultiplier = 1;
433   for (int i = rank - 1; i >= 0; --i) {
434     valueIndex += index[i] * dimMultiplier;
435     dimMultiplier *= shape[i];
436   }
437   return valueIndex;
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // DenseElementAttr Utilities
442 //===----------------------------------------------------------------------===//
443 
444 static size_t getDenseElementBitwidth(Type eltType) {
445   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
446   // with double semantics.
447   return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
448 }
449 
/// Get the bitwidth of a dense element type within the buffer.
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  // i1 elements are bit-packed and keep their width of one bit.
  if (origWidth == 1)
    return origWidth;
  // Any other width is rounded up to the next multiple of 8.
  return (origWidth + 7) / 8 * 8;
}
455 
/// Set a bit to a specific value. Bits are numbered LSB-first within each
/// byte: bit `bitPos` lives in byte `bitPos / CHAR_BIT` at position
/// `bitPos % CHAR_BIT`.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  byte = value ? (byte | mask) : (byte & ~mask);
}

/// Return the value of the specified bit, using the same LSB-first layout as
/// setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
468 
469 /// Writes value to the bit position `bitPos` in array `rawData`.
470 static void writeBits(char *rawData, size_t bitPos, APInt value) {
471   size_t bitWidth = value.getBitWidth();
472 
473   // If the bitwidth is 1 we just toggle the specific bit.
474   if (bitWidth == 1)
475     return setBit(rawData, bitPos, value.isOneValue());
476 
477   // Otherwise, the bit position is guaranteed to be byte aligned.
478   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
479   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
480               llvm::divideCeil(bitWidth, CHAR_BIT),
481               rawData + (bitPos / CHAR_BIT));
482 }
483 
484 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
485 /// `rawData`.
486 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
487   // Handle a boolean bit position.
488   if (bitWidth == 1)
489     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
490 
491   // Otherwise, the bit position must be 8-bit aligned.
492   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
493   APInt result(bitWidth, 0);
494   std::copy_n(
495       rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
496       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
497   return result;
498 }
499 
500 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
501 /// same element count as 'type'.
502 template <typename Values>
503 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
504   return (values.size() == 1) ||
505          (type.getNumElements() == static_cast<int64_t>(values.size()));
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // DenseElementAttr Iterators
510 //===----------------------------------------------------------------------===//
511 
512 /// Constructs a new iterator.
513 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
514     DenseElementsAttr attr, size_t index)
515     : indexed_accessor_iterator<AttributeElementIterator, const void *,
516                                 Attribute, Attribute, Attribute>(
517           attr.getAsOpaquePointer(), index) {}
518 
519 /// Accesses the Attribute value at this iterator position.
520 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
521   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
522   Type eltTy = owner.getType().getElementType();
523   if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
524     if (intEltTy.getWidth() == 1)
525       return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
526                            owner.getContext());
527     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
528   }
529   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
530     IntElementIterator intIt(owner, index);
531     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
532     return FloatAttr::get(eltTy, *floatIt);
533   }
534   llvm_unreachable("unexpected element type");
535 }
536 
537 /// Constructs a new iterator.
538 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
539     DenseElementsAttr attr, size_t dataIndex)
540     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
541           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
542 
543 /// Accesses the bool value at this iterator position.
544 bool DenseElementsAttr::BoolElementIterator::operator*() const {
545   return getBit(getData(), getDataIndex());
546 }
547 
548 /// Constructs a new iterator.
549 DenseElementsAttr::IntElementIterator::IntElementIterator(
550     DenseElementsAttr attr, size_t dataIndex)
551     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
552           attr.getRawData().data(), attr.isSplat(), dataIndex),
553       bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
554 
555 /// Accesses the raw APInt value at this iterator position.
556 APInt DenseElementsAttr::IntElementIterator::operator*() const {
557   return readBits(getData(),
558                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
559                   bitWidth);
560 }
561 
562 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
563     const llvm::fltSemantics &smt, IntElementIterator it)
564     : llvm::mapped_iterator<IntElementIterator,
565                             std::function<APFloat(const APInt &)>>(
566           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
567 
568 //===----------------------------------------------------------------------===//
569 // DenseElementsAttr
570 //===----------------------------------------------------------------------===//
571 
572 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
573                                          ArrayRef<Attribute> values) {
574   assert(type.getElementType().isIntOrFloat() &&
575          "expected int or float element type");
576   assert(hasSameElementsOrSplat(type, values));
577 
578   auto eltType = type.getElementType();
579   size_t bitWidth = getDenseElementBitwidth(eltType);
580   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
581 
582   // Compress the attribute values into a character buffer.
583   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
584                             values.size());
585   APInt intVal;
586   for (unsigned i = 0, e = values.size(); i < e; ++i) {
587     assert(eltType == values[i].getType() &&
588            "expected attribute value to have element type");
589 
590     switch (eltType.getKind()) {
591     case StandardTypes::BF16:
592     case StandardTypes::F16:
593     case StandardTypes::F32:
594     case StandardTypes::F64:
595       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
596       break;
597     case StandardTypes::Integer:
598       intVal = values[i].isa<BoolAttr>()
599                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
600                    : values[i].cast<IntegerAttr>().getValue();
601       break;
602     default:
603       llvm_unreachable("unexpected element type");
604     }
605     assert(intVal.getBitWidth() == bitWidth &&
606            "expected value to have same bitwidth as element type");
607     writeBits(data.data(), i * storageBitWidth, intVal);
608   }
609   return getRaw(type, data, /*isSplat=*/(values.size() == 1));
610 }
611 
612 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
613                                          ArrayRef<bool> values) {
614   assert(hasSameElementsOrSplat(type, values));
615   assert(type.getElementType().isInteger(1));
616 
617   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
618   for (int i = 0, e = values.size(); i != e; ++i)
619     setBit(buff.data(), i, values[i]);
620   return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
621 }
622 
623 /// Constructs a dense integer elements attribute from an array of APInt
624 /// values. Each APInt value is expected to have the same bitwidth as the
625 /// element type of 'type'.
626 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
627                                          ArrayRef<APInt> values) {
628   assert(type.getElementType().isa<IntegerType>());
629   return getRaw(type, values);
630 }
631 
632 // Constructs a dense float elements attribute from an array of APFloat
633 // values. Each APFloat value is expected to have the same bitwidth as the
634 // element type of 'type'.
635 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
636                                          ArrayRef<APFloat> values) {
637   assert(type.getElementType().isa<FloatType>());
638 
639   // Convert the APFloat values to APInt and create a dense elements attribute.
640   std::vector<APInt> intValues(values.size());
641   for (unsigned i = 0, e = values.size(); i != e; ++i)
642     intValues[i] = values[i].bitcastToAPInt();
643   return getRaw(type, intValues);
644 }
645 
646 // Constructs a dense elements attribute from an array of raw APInt values.
647 // Each APInt value is expected to have the same bitwidth as the element type
648 // of 'type'.
649 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
650                                             ArrayRef<APInt> values) {
651   assert(hasSameElementsOrSplat(type, values));
652 
653   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
654   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
655   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
656                                 values.size());
657   for (unsigned i = 0, e = values.size(); i != e; ++i) {
658     assert(values[i].getBitWidth() == bitWidth);
659     writeBits(elementData.data(), i * storageBitWidth, values[i]);
660   }
661   return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
662 }
663 
664 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
665                                             ArrayRef<char> data, bool isSplat) {
666   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
667          "type must be ranked tensor or vector");
668   assert(type.hasStaticShape() && "type must have static shape");
669   return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
670                    data, isSplat);
671 }
672 
673 /// Check the information for a c++ data type, check if this type is valid for
674 /// the current attribute. This method is used to verify specific type
675 /// invariants that the templatized 'getValues' method cannot.
676 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
677                               bool isInt) {
678   // Make sure that the data element size is the same as the type element width.
679   if (getDenseElementBitwidth(type.getElementType()) !=
680       static_cast<size_t>(dataEltSize * CHAR_BIT))
681     return false;
682 
683   // Check that the element type is valid.
684   return isInt ? type.getElementType().isa<IntegerType>()
685                : type.getElementType().isa<FloatType>();
686 }
687 
688 /// Overload of the 'getRaw' method that asserts that the given type is of
689 /// integer type. This method is used to verify type invariants that the
690 /// templatized 'get' method cannot.
691 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
692                                                       ArrayRef<char> data,
693                                                       int64_t dataEltSize,
694                                                       bool isInt) {
695   assert(::isValidIntOrFloat(type, dataEltSize, isInt));
696 
697   int64_t numElements = data.size() / dataEltSize;
698   assert(numElements == 1 || numElements == type.getNumElements());
699   return getRaw(type, data, /*isSplat=*/numElements == 1);
700 }
701 
702 /// A method used to verify specific type invariants that the templatized 'get'
703 /// method cannot.
704 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
705                                           bool isInt) const {
706   return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
707 }
708 
709 /// Return the raw storage data held by this attribute.
710 ArrayRef<char> DenseElementsAttr::getRawData() const {
711   return static_cast<ImplType *>(impl)->data;
712 }
713 
714 /// Returns if this attribute corresponds to a splat, i.e. if all element
715 /// values are the same.
716 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
717 
718 /// Return the held element values as a range of Attributes.
719 auto DenseElementsAttr::getAttributeValues() const
720     -> llvm::iterator_range<AttributeElementIterator> {
721   return {attr_value_begin(), attr_value_end()};
722 }
723 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
724   return AttributeElementIterator(*this, 0);
725 }
726 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
727   return AttributeElementIterator(*this, getNumElements());
728 }
729 
730 /// Return the held element values as a range of bool. The element type of
731 /// this attribute must be of integer type of bitwidth 1.
732 auto DenseElementsAttr::getBoolValues() const
733     -> llvm::iterator_range<BoolElementIterator> {
734   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
735   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
736   (void)eltType;
737   return {BoolElementIterator(*this, 0),
738           BoolElementIterator(*this, getNumElements())};
739 }
740 
741 /// Return the held element values as a range of APInts. The element type of
742 /// this attribute must be of integer type.
743 auto DenseElementsAttr::getIntValues() const
744     -> llvm::iterator_range<IntElementIterator> {
745   assert(getType().getElementType().isa<IntegerType>() &&
746          "expected integer type");
747   return {raw_int_begin(), raw_int_end()};
748 }
749 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
750   assert(getType().getElementType().isa<IntegerType>() &&
751          "expected integer type");
752   return raw_int_begin();
753 }
754 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
755   assert(getType().getElementType().isa<IntegerType>() &&
756          "expected integer type");
757   return raw_int_end();
758 }
759 
760 /// Return the held element values as a range of APFloat. The element type of
761 /// this attribute must be of float type.
762 auto DenseElementsAttr::getFloatValues() const
763     -> llvm::iterator_range<FloatElementIterator> {
764   auto elementType = getType().getElementType().cast<FloatType>();
765   assert(elementType.isa<FloatType>() && "expected float type");
766   const auto &elementSemantics = elementType.getFloatSemantics();
767   return {FloatElementIterator(elementSemantics, raw_int_begin()),
768           FloatElementIterator(elementSemantics, raw_int_end())};
769 }
770 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
771   return getFloatValues().begin();
772 }
773 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
774   return getFloatValues().end();
775 }
776 
777 /// Return a new DenseElementsAttr that has the same data as the current
778 /// attribute, but has been reshaped to 'newType'. The new type must have the
779 /// same total number of elements as well as element type.
780 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
781   ShapedType curType = getType();
782   if (curType == newType)
783     return *this;
784 
785   (void)curType;
786   assert(newType.getElementType() == curType.getElementType() &&
787          "expected the same element type");
788   assert(newType.getNumElements() == curType.getNumElements() &&
789          "expected the same number of elements");
790   return getRaw(newType, getRawData(), isSplat());
791 }
792 
793 DenseElementsAttr
794 DenseElementsAttr::mapValues(Type newElementType,
795                              function_ref<APInt(const APInt &)> mapping) const {
796   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
797 }
798 
799 DenseElementsAttr DenseElementsAttr::mapValues(
800     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
801   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
802 }
803 
804 //===----------------------------------------------------------------------===//
805 // DenseFPElementsAttr
806 //===----------------------------------------------------------------------===//
807 
808 template <typename Fn, typename Attr>
809 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
810                                 Type newElementType,
811                                 llvm::SmallVectorImpl<char> &data) {
812   size_t bitWidth = getDenseElementBitwidth(newElementType);
813   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
814 
815   ShapedType newArrayType;
816   if (inType.isa<RankedTensorType>())
817     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
818   else if (inType.isa<UnrankedTensorType>())
819     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
820   else if (inType.isa<VectorType>())
821     newArrayType = VectorType::get(inType.getShape(), newElementType);
822   else
823     assert(newArrayType && "Unhandled tensor type");
824 
825   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
826   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
827 
828   // Functor used to process a single element value of the attribute.
829   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
830     auto newInt = mapping(value);
831     assert(newInt.getBitWidth() == bitWidth);
832     writeBits(data.data(), index * storageBitWidth, newInt);
833   };
834 
835   // Check for the splat case.
836   if (attr.isSplat()) {
837     processElt(*attr.begin(), /*index=*/0);
838     return newArrayType;
839   }
840 
841   // Otherwise, process all of the element values.
842   uint64_t elementIdx = 0;
843   for (auto value : attr)
844     processElt(value, elementIdx++);
845   return newArrayType;
846 }
847 
848 DenseElementsAttr DenseFPElementsAttr::mapValues(
849     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
850   llvm::SmallVector<char, 8> elementData;
851   auto newArrayType =
852       mappingHelper(mapping, *this, getType(), newElementType, elementData);
853 
854   return getRaw(newArrayType, elementData, isSplat());
855 }
856 
857 /// Method for supporting type inquiry through isa, cast and dyn_cast.
858 bool DenseFPElementsAttr::classof(Attribute attr) {
859   return attr.isa<DenseElementsAttr>() &&
860          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // DenseIntElementsAttr
865 //===----------------------------------------------------------------------===//
866 
867 DenseElementsAttr DenseIntElementsAttr::mapValues(
868     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
869   llvm::SmallVector<char, 8> elementData;
870   auto newArrayType =
871       mappingHelper(mapping, *this, getType(), newElementType, elementData);
872 
873   return getRaw(newArrayType, elementData, isSplat());
874 }
875 
876 /// Method for supporting type inquiry through isa, cast and dyn_cast.
877 bool DenseIntElementsAttr::classof(Attribute attr) {
878   return attr.isa<DenseElementsAttr>() &&
879          attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // OpaqueElementsAttr
884 //===----------------------------------------------------------------------===//
885 
886 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
887                                            StringRef bytes) {
888   assert(TensorType::isValidElementType(type.getElementType()) &&
889          "Input element type should be a valid tensor element type");
890   return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
891                    dialect, bytes);
892 }
893 
894 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
895 
896 /// Return the value at the given index. If index does not refer to a valid
897 /// element, then a null attribute is returned.
898 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
899   assert(isValidIndex(index) && "expected valid multi-dimensional index");
900   if (Dialect *dialect = getDialect())
901     return dialect->extractElementHook(*this, index);
902   return Attribute();
903 }
904 
905 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
906 
907 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
908   if (auto *d = getDialect())
909     return d->decodeHook(*this, result);
910   return true;
911 }
912 
913 //===----------------------------------------------------------------------===//
914 // SparseElementsAttr
915 //===----------------------------------------------------------------------===//
916 
917 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
918                                            DenseElementsAttr indices,
919                                            DenseElementsAttr values) {
920   assert(indices.getType().getElementType().isInteger(64) &&
921          "expected sparse indices to be 64-bit integer values");
922   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
923          "type must be ranked tensor or vector");
924   assert(type.hasStaticShape() && "type must have static shape");
925   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
926                    indices.cast<DenseIntElementsAttr>(), values);
927 }
928 
929 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
930   return getImpl()->indices;
931 }
932 
933 DenseElementsAttr SparseElementsAttr::getValues() const {
934   return getImpl()->values;
935 }
936 
937 /// Return the value of the element at the given index.
938 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
939   assert(isValidIndex(index) && "expected valid multi-dimensional index");
940   auto type = getType();
941 
942   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
943   // as a 1-D index array.
944   auto sparseIndices = getIndices();
945   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
946 
947   // Check to see if the indices are a splat.
948   if (sparseIndices.isSplat()) {
949     // If the index is also not a splat of the index value, we know that the
950     // value is zero.
951     auto splatIndex = *sparseIndexValues.begin();
952     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
953       return getZeroAttr();
954 
955     // If the indices are a splat, we also expect the values to be a splat.
956     assert(getValues().isSplat() && "expected splat values");
957     return getValues().getSplatValue();
958   }
959 
960   // Build a mapping between known indices and the offset of the stored element.
961   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
962   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
963   size_t rank = type.getRank();
964   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
965     mappedIndices.try_emplace(
966         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
967 
968   // Look for the provided index key within the mapped indices. If the provided
969   // index is not found, then return a zero attribute.
970   auto it = mappedIndices.find(index);
971   if (it == mappedIndices.end())
972     return getZeroAttr();
973 
974   // Otherwise, return the held sparse value element.
975   return getValues().getValue(it->second);
976 }
977 
978 /// Get a zero APFloat for the given sparse attribute.
979 APFloat SparseElementsAttr::getZeroAPFloat() const {
980   auto eltType = getType().getElementType().cast<FloatType>();
981   return APFloat(eltType.getFloatSemantics());
982 }
983 
984 /// Get a zero APInt for the given sparse attribute.
985 APInt SparseElementsAttr::getZeroAPInt() const {
986   auto eltType = getType().getElementType().cast<IntegerType>();
987   return APInt::getNullValue(eltType.getWidth());
988 }
989 
990 /// Get a zero attribute for the given attribute type.
991 Attribute SparseElementsAttr::getZeroAttr() const {
992   auto eltType = getType().getElementType();
993 
994   // Handle floating point elements.
995   if (eltType.isa<FloatType>())
996     return FloatAttr::get(eltType, 0);
997 
998   // Otherwise, this is an integer.
999   auto intEltTy = eltType.cast<IntegerType>();
1000   if (intEltTy.getWidth() == 1)
1001     return BoolAttr::get(false, eltType.getContext());
1002   return IntegerAttr::get(eltType, 0);
1003 }
1004 
1005 /// Flatten, and return, all of the sparse indices in this attribute in
1006 /// row-major order.
1007 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1008   std::vector<ptrdiff_t> flatSparseIndices;
1009 
1010   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1011   // as a 1-D index array.
1012   auto sparseIndices = getIndices();
1013   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1014   if (sparseIndices.isSplat()) {
1015     SmallVector<uint64_t, 8> indices(getType().getRank(),
1016                                      *sparseIndexValues.begin());
1017     flatSparseIndices.push_back(getFlattenedIndex(indices));
1018     return flatSparseIndices;
1019   }
1020 
1021   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1022   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1023   size_t rank = getType().getRank();
1024   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1025     flatSparseIndices.push_back(getFlattenedIndex(
1026         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1027   return flatSparseIndices;
1028 }
1029 
1030 //===----------------------------------------------------------------------===//
1031 // NamedAttributeList
1032 //===----------------------------------------------------------------------===//
1033 
1034 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
1035   setAttrs(attributes);
1036 }
1037 
1038 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
1039   return attrs ? attrs.getValue() : llvm::None;
1040 }
1041 
1042 /// Replace the held attributes with ones provided in 'newAttrs'.
1043 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
1044   // Don't create an attribute list if there are no attributes.
1045   if (attributes.empty())
1046     attrs = nullptr;
1047   else
1048     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1049 }
1050 
1051 /// Return the specified attribute if present, null otherwise.
1052 Attribute NamedAttributeList::get(StringRef name) const {
1053   return attrs ? attrs.get(name) : nullptr;
1054 }
1055 
1056 /// Return the specified attribute if present, null otherwise.
1057 Attribute NamedAttributeList::get(Identifier name) const {
1058   return attrs ? attrs.get(name) : nullptr;
1059 }
1060 
1061 /// If the an attribute exists with the specified name, change it to the new
1062 /// value.  Otherwise, add a new attribute with the specified name/value.
1063 void NamedAttributeList::set(Identifier name, Attribute value) {
1064   assert(value && "attributes may never be null");
1065 
1066   // If we already have this attribute, replace it.
1067   auto origAttrs = getAttrs();
1068   SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
1069   for (auto &elt : newAttrs)
1070     if (elt.first == name) {
1071       elt.second = value;
1072       attrs = DictionaryAttr::get(newAttrs, value.getContext());
1073       return;
1074     }
1075 
1076   // Otherwise, add it.
1077   newAttrs.push_back({name, value});
1078   attrs = DictionaryAttr::get(newAttrs, value.getContext());
1079 }
1080 
1081 /// Remove the attribute with the specified name if it exists.  The return
1082 /// value indicates whether the attribute was present or not.
1083 auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
1084   auto origAttrs = getAttrs();
1085   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1086     if (origAttrs[i].first == name) {
1087       // Handle the simple case of removing the only attribute in the list.
1088       if (e == 1) {
1089         attrs = nullptr;
1090         return RemoveResult::Removed;
1091       }
1092 
1093       SmallVector<NamedAttribute, 8> newAttrs;
1094       newAttrs.reserve(origAttrs.size() - 1);
1095       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1096       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1097       attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
1098       return RemoveResult::Removed;
1099     }
1100   }
1101   return RemoveResult::NotFound;
1102 }
1103