1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir/CAPI/AffineMap.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 
16 using namespace mlir;
17 
18 MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
19 
20 //===----------------------------------------------------------------------===//
21 // Affine map attribute.
22 //===----------------------------------------------------------------------===//
23 
24 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
25   return unwrap(attr).isa<AffineMapAttr>();
26 }
27 
28 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
29   return wrap(AffineMapAttr::get(unwrap(map)));
30 }
31 
32 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
33   return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Array attribute.
38 //===----------------------------------------------------------------------===//
39 
40 bool mlirAttributeIsAArray(MlirAttribute attr) {
41   return unwrap(attr).isa<ArrayAttr>();
42 }
43 
44 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
45                                MlirAttribute const *elements) {
46   SmallVector<Attribute, 8> attrs;
47   return wrap(
48       ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
49                                              elements, attrs)));
50 }
51 
52 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
53   return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
54 }
55 
56 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
57   return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Dictionary attribute.
62 //===----------------------------------------------------------------------===//
63 
64 bool mlirAttributeIsADictionary(MlirAttribute attr) {
65   return unwrap(attr).isa<DictionaryAttr>();
66 }
67 
68 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
69                                     MlirNamedAttribute const *elements) {
70   SmallVector<NamedAttribute, 8> attributes;
71   attributes.reserve(numElements);
72   for (intptr_t i = 0; i < numElements; ++i)
73     attributes.emplace_back(unwrap(elements[i].name),
74                             unwrap(elements[i].attribute));
75   return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
76 }
77 
78 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
79   return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
80 }
81 
82 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
83                                                 intptr_t pos) {
84   NamedAttribute attribute =
85       unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
86   return {wrap(attribute.getName()), wrap(attribute.getValue())};
87 }
88 
89 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
90                                                  MlirStringRef name) {
91   return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Floating point attribute.
96 //===----------------------------------------------------------------------===//
97 
98 bool mlirAttributeIsAFloat(MlirAttribute attr) {
99   return unwrap(attr).isa<FloatAttr>();
100 }
101 
102 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
103                                      double value) {
104   return wrap(FloatAttr::get(unwrap(type), value));
105 }
106 
107 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
108                                             double value) {
109   return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
110 }
111 
112 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
113   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // Integer attribute.
118 //===----------------------------------------------------------------------===//
119 
120 bool mlirAttributeIsAInteger(MlirAttribute attr) {
121   return unwrap(attr).isa<IntegerAttr>();
122 }
123 
124 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
125   return wrap(IntegerAttr::get(unwrap(type), value));
126 }
127 
128 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
129   return unwrap(attr).cast<IntegerAttr>().getInt();
130 }
131 
132 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
133   return unwrap(attr).cast<IntegerAttr>().getSInt();
134 }
135 
136 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
137   return unwrap(attr).cast<IntegerAttr>().getUInt();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // Bool attribute.
142 //===----------------------------------------------------------------------===//
143 
144 bool mlirAttributeIsABool(MlirAttribute attr) {
145   return unwrap(attr).isa<BoolAttr>();
146 }
147 
148 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
149   return wrap(BoolAttr::get(unwrap(ctx), value));
150 }
151 
152 bool mlirBoolAttrGetValue(MlirAttribute attr) {
153   return unwrap(attr).cast<BoolAttr>().getValue();
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Integer set attribute.
158 //===----------------------------------------------------------------------===//
159 
160 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
161   return unwrap(attr).isa<IntegerSetAttr>();
162 }
163 
164 //===----------------------------------------------------------------------===//
165 // Opaque attribute.
166 //===----------------------------------------------------------------------===//
167 
168 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
169   return unwrap(attr).isa<OpaqueAttr>();
170 }
171 
172 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
173                                 intptr_t dataLength, const char *data,
174                                 MlirType type) {
175   return wrap(
176       OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
177                       StringRef(data, dataLength), unwrap(type)));
178 }
179 
180 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
181   return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
182 }
183 
184 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
185   return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // String attribute.
190 //===----------------------------------------------------------------------===//
191 
192 bool mlirAttributeIsAString(MlirAttribute attr) {
193   return unwrap(attr).isa<StringAttr>();
194 }
195 
196 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
197   return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
198 }
199 
200 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
201   return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
202 }
203 
204 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
205   return wrap(unwrap(attr).cast<StringAttr>().getValue());
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // SymbolRef attribute.
210 //===----------------------------------------------------------------------===//
211 
212 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
213   return unwrap(attr).isa<SymbolRefAttr>();
214 }
215 
216 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
217                                    intptr_t numReferences,
218                                    MlirAttribute const *references) {
219   SmallVector<FlatSymbolRefAttr, 4> refs;
220   refs.reserve(numReferences);
221   for (intptr_t i = 0; i < numReferences; ++i)
222     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
223   auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
224   return wrap(SymbolRefAttr::get(symbolAttr, refs));
225 }
226 
227 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
228   return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference().getValue());
229 }
230 
231 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
232   return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference().getValue());
233 }
234 
235 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
236   return static_cast<intptr_t>(
237       unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
238 }
239 
240 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
241                                                   intptr_t pos) {
242   return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // Flat SymbolRef attribute.
247 //===----------------------------------------------------------------------===//
248 
249 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
250   return unwrap(attr).isa<FlatSymbolRefAttr>();
251 }
252 
253 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
254   return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
255 }
256 
257 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
258   return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Type attribute.
263 //===----------------------------------------------------------------------===//
264 
265 bool mlirAttributeIsAType(MlirAttribute attr) {
266   return unwrap(attr).isa<TypeAttr>();
267 }
268 
269 MlirAttribute mlirTypeAttrGet(MlirType type) {
270   return wrap(TypeAttr::get(unwrap(type)));
271 }
272 
273 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
274   return wrap(unwrap(attr).cast<TypeAttr>().getValue());
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Unit attribute.
279 //===----------------------------------------------------------------------===//
280 
281 bool mlirAttributeIsAUnit(MlirAttribute attr) {
282   return unwrap(attr).isa<UnitAttr>();
283 }
284 
285 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
286   return wrap(UnitAttr::get(unwrap(ctx)));
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // Elements attributes.
291 //===----------------------------------------------------------------------===//
292 
293 bool mlirAttributeIsAElements(MlirAttribute attr) {
294   return unwrap(attr).isa<ElementsAttr>();
295 }
296 
297 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
298                                        uint64_t *idxs) {
299   return wrap(unwrap(attr)
300                   .cast<ElementsAttr>()
301                   .getValues<Attribute>()[llvm::makeArrayRef(idxs, rank)]);
302 }
303 
304 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
305                                   uint64_t *idxs) {
306   return unwrap(attr).cast<ElementsAttr>().isValidIndex(
307       llvm::makeArrayRef(idxs, rank));
308 }
309 
310 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
311   return unwrap(attr).cast<ElementsAttr>().getNumElements();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Dense elements attribute.
316 //===----------------------------------------------------------------------===//
317 
318 //===----------------------------------------------------------------------===//
319 // IsA support.
320 
321 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
322   return unwrap(attr).isa<DenseElementsAttr>();
323 }
324 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
325   return unwrap(attr).isa<DenseIntElementsAttr>();
326 }
327 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
328   return unwrap(attr).isa<DenseFPElementsAttr>();
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // Constructors.
333 
334 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
335                                        intptr_t numElements,
336                                        MlirAttribute const *elements) {
337   SmallVector<Attribute, 8> attributes;
338   return wrap(
339       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
340                              unwrapList(numElements, elements, attributes)));
341 }
342 
343 MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
344                                                 size_t rawBufferSize,
345                                                 const void *rawBuffer) {
346   auto shapedTypeCpp = unwrap(shapedType).cast<ShapedType>();
347   ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
348                               rawBufferSize);
349   bool isSplat = false;
350   if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
351                                            isSplat))
352     return mlirAttributeGetNull();
353   return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
354 }
355 
356 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
357                                             MlirAttribute element) {
358   return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
359                                      unwrap(element)));
360 }
361 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
362                                                 bool element) {
363   return wrap(
364       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
365 }
366 MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
367                                                  uint8_t element) {
368   return wrap(
369       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
370 }
371 MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
372                                                 int8_t element) {
373   return wrap(
374       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
375 }
376 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
377                                                   uint32_t element) {
378   return wrap(
379       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
380 }
381 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
382                                                  int32_t element) {
383   return wrap(
384       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
385 }
386 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
387                                                   uint64_t element) {
388   return wrap(
389       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
390 }
391 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
392                                                  int64_t element) {
393   return wrap(
394       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
395 }
396 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
397                                                  float element) {
398   return wrap(
399       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
400 }
401 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
402                                                   double element) {
403   return wrap(
404       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
405 }
406 
407 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
408                                            intptr_t numElements,
409                                            const int *elements) {
410   SmallVector<bool, 8> values(elements, elements + numElements);
411   return wrap(
412       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
413 }
414 
415 /// Creates a dense attribute with elements of the type deduced by templates.
416 template <typename T>
417 static MlirAttribute getDenseAttribute(MlirType shapedType,
418                                        intptr_t numElements,
419                                        const T *elements) {
420   return wrap(
421       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
422                              llvm::makeArrayRef(elements, numElements)));
423 }
424 
425 MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
426                                             intptr_t numElements,
427                                             const uint8_t *elements) {
428   return getDenseAttribute(shapedType, numElements, elements);
429 }
430 MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
431                                            intptr_t numElements,
432                                            const int8_t *elements) {
433   return getDenseAttribute(shapedType, numElements, elements);
434 }
435 MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
436                                              intptr_t numElements,
437                                              const uint16_t *elements) {
438   return getDenseAttribute(shapedType, numElements, elements);
439 }
440 MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
441                                             intptr_t numElements,
442                                             const int16_t *elements) {
443   return getDenseAttribute(shapedType, numElements, elements);
444 }
445 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
446                                              intptr_t numElements,
447                                              const uint32_t *elements) {
448   return getDenseAttribute(shapedType, numElements, elements);
449 }
450 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
451                                             intptr_t numElements,
452                                             const int32_t *elements) {
453   return getDenseAttribute(shapedType, numElements, elements);
454 }
455 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
456                                              intptr_t numElements,
457                                              const uint64_t *elements) {
458   return getDenseAttribute(shapedType, numElements, elements);
459 }
460 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
461                                             intptr_t numElements,
462                                             const int64_t *elements) {
463   return getDenseAttribute(shapedType, numElements, elements);
464 }
465 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
466                                             intptr_t numElements,
467                                             const float *elements) {
468   return getDenseAttribute(shapedType, numElements, elements);
469 }
470 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
471                                              intptr_t numElements,
472                                              const double *elements) {
473   return getDenseAttribute(shapedType, numElements, elements);
474 }
475 MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
476                                                intptr_t numElements,
477                                                const uint16_t *elements) {
478   size_t bufferSize = numElements * 2;
479   const void *buffer = static_cast<const void *>(elements);
480   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
481 }
482 MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
483                                               intptr_t numElements,
484                                               const uint16_t *elements) {
485   size_t bufferSize = numElements * 2;
486   const void *buffer = static_cast<const void *>(elements);
487   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
488 }
489 
490 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
491                                              intptr_t numElements,
492                                              MlirStringRef *strs) {
493   SmallVector<StringRef, 8> values;
494   values.reserve(numElements);
495   for (intptr_t i = 0; i < numElements; ++i)
496     values.push_back(unwrap(strs[i]));
497 
498   return wrap(
499       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
500 }
501 
502 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
503                                               MlirType shapedType) {
504   return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
505       unwrap(shapedType).cast<ShapedType>()));
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // Splat accessors.
510 
511 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
512   return unwrap(attr).cast<DenseElementsAttr>().isSplat();
513 }
514 
515 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
516   return wrap(
517       unwrap(attr).cast<DenseElementsAttr>().getSplatValue<Attribute>());
518 }
519 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
520   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
521 }
522 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
523   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
524 }
525 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
526   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
527 }
528 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
529   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
530 }
531 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
532   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
533 }
534 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
535   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
536 }
537 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
538   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
539 }
540 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
541   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
542 }
543 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
544   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
545 }
546 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
547   return wrap(
548       unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // Indexed accessors.
553 
554 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
555   return unwrap(attr).cast<DenseElementsAttr>().getValues<bool>()[pos];
556 }
557 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
558   return unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>()[pos];
559 }
560 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
561   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
562 }
563 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
564   return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
565 }
566 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
567   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
568 }
569 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
570   return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
571 }
572 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
573   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>()[pos];
574 }
575 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
576   return unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>()[pos];
577 }
578 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
579   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>()[pos];
580 }
581 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
582   return unwrap(attr).cast<DenseElementsAttr>().getValues<float>()[pos];
583 }
584 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
585   return unwrap(attr).cast<DenseElementsAttr>().getValues<double>()[pos];
586 }
587 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
588                                                   intptr_t pos) {
589   return wrap(
590       unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>()[pos]);
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // Raw data accessors.
595 
596 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
597   return static_cast<const void *>(
598       unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // Opaque elements attribute.
603 //===----------------------------------------------------------------------===//
604 
605 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
606   return unwrap(attr).isa<OpaqueElementsAttr>();
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // Sparse elements attribute.
611 //===----------------------------------------------------------------------===//
612 
613 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
614   return unwrap(attr).isa<SparseElementsAttr>();
615 }
616 
617 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
618                                           MlirAttribute denseIndices,
619                                           MlirAttribute denseValues) {
620   return wrap(
621       SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
622                               unwrap(denseIndices).cast<DenseElementsAttr>(),
623                               unwrap(denseValues).cast<DenseElementsAttr>()));
624 }
625 
626 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
627   return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
628 }
629 
630 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
631   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
632 }
633