1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir/CAPI/AffineMap.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // Affine map attribute.
20 //===----------------------------------------------------------------------===//
21 
22 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
23   return unwrap(attr).isa<AffineMapAttr>();
24 }
25 
26 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
27   return wrap(AffineMapAttr::get(unwrap(map)));
28 }
29 
30 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
31   return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // Array attribute.
36 //===----------------------------------------------------------------------===//
37 
38 bool mlirAttributeIsAArray(MlirAttribute attr) {
39   return unwrap(attr).isa<ArrayAttr>();
40 }
41 
42 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
43                                MlirAttribute const *elements) {
44   SmallVector<Attribute, 8> attrs;
45   return wrap(
46       ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
47                                              elements, attrs)));
48 }
49 
50 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
51   return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
52 }
53 
54 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
55   return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Dictionary attribute.
60 //===----------------------------------------------------------------------===//
61 
62 bool mlirAttributeIsADictionary(MlirAttribute attr) {
63   return unwrap(attr).isa<DictionaryAttr>();
64 }
65 
66 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
67                                     MlirNamedAttribute const *elements) {
68   SmallVector<NamedAttribute, 8> attributes;
69   attributes.reserve(numElements);
70   for (intptr_t i = 0; i < numElements; ++i)
71     attributes.emplace_back(
72         Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
73         unwrap(elements[i].attribute));
74   return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
75 }
76 
77 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
78   return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
79 }
80 
81 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
82                                                 intptr_t pos) {
83   NamedAttribute attribute =
84       unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
85   return {wrap(attribute.first), wrap(attribute.second)};
86 }
87 
88 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
89                                                  MlirStringRef name) {
90   return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Floating point attribute.
95 //===----------------------------------------------------------------------===//
96 
97 bool mlirAttributeIsAFloat(MlirAttribute attr) {
98   return unwrap(attr).isa<FloatAttr>();
99 }
100 
101 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
102                                      double value) {
103   return wrap(FloatAttr::get(unwrap(type), value));
104 }
105 
106 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
107                                             double value) {
108   return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
109 }
110 
111 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
112   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // Integer attribute.
117 //===----------------------------------------------------------------------===//
118 
119 bool mlirAttributeIsAInteger(MlirAttribute attr) {
120   return unwrap(attr).isa<IntegerAttr>();
121 }
122 
123 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
124   return wrap(IntegerAttr::get(unwrap(type), value));
125 }
126 
127 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
128   return unwrap(attr).cast<IntegerAttr>().getInt();
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Bool attribute.
133 //===----------------------------------------------------------------------===//
134 
135 bool mlirAttributeIsABool(MlirAttribute attr) {
136   return unwrap(attr).isa<BoolAttr>();
137 }
138 
139 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
140   return wrap(BoolAttr::get(unwrap(ctx), value));
141 }
142 
143 bool mlirBoolAttrGetValue(MlirAttribute attr) {
144   return unwrap(attr).cast<BoolAttr>().getValue();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Integer set attribute.
149 //===----------------------------------------------------------------------===//
150 
151 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
152   return unwrap(attr).isa<IntegerSetAttr>();
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // Opaque attribute.
157 //===----------------------------------------------------------------------===//
158 
159 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
160   return unwrap(attr).isa<OpaqueAttr>();
161 }
162 
163 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
164                                 intptr_t dataLength, const char *data,
165                                 MlirType type) {
166   return wrap(
167       OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
168                       StringRef(data, dataLength), unwrap(type)));
169 }
170 
171 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
172   return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
173 }
174 
175 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
176   return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // String attribute.
181 //===----------------------------------------------------------------------===//
182 
183 bool mlirAttributeIsAString(MlirAttribute attr) {
184   return unwrap(attr).isa<StringAttr>();
185 }
186 
187 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
188   return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
189 }
190 
191 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
192   return wrap(StringAttr::get(unwrap(str), unwrap(type)));
193 }
194 
195 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
196   return wrap(unwrap(attr).cast<StringAttr>().getValue());
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // SymbolRef attribute.
201 //===----------------------------------------------------------------------===//
202 
203 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
204   return unwrap(attr).isa<SymbolRefAttr>();
205 }
206 
207 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
208                                    intptr_t numReferences,
209                                    MlirAttribute const *references) {
210   SmallVector<FlatSymbolRefAttr, 4> refs;
211   refs.reserve(numReferences);
212   for (intptr_t i = 0; i < numReferences; ++i)
213     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
214   return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
215 }
216 
217 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
218   return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
219 }
220 
221 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
222   return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
223 }
224 
225 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
226   return static_cast<intptr_t>(
227       unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
228 }
229 
230 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
231                                                   intptr_t pos) {
232   return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // Flat SymbolRef attribute.
237 //===----------------------------------------------------------------------===//
238 
239 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
240   return unwrap(attr).isa<FlatSymbolRefAttr>();
241 }
242 
243 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
244   return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
245 }
246 
247 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
248   return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Type attribute.
253 //===----------------------------------------------------------------------===//
254 
255 bool mlirAttributeIsAType(MlirAttribute attr) {
256   return unwrap(attr).isa<TypeAttr>();
257 }
258 
259 MlirAttribute mlirTypeAttrGet(MlirType type) {
260   return wrap(TypeAttr::get(unwrap(type)));
261 }
262 
263 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
264   return wrap(unwrap(attr).cast<TypeAttr>().getValue());
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Unit attribute.
269 //===----------------------------------------------------------------------===//
270 
271 bool mlirAttributeIsAUnit(MlirAttribute attr) {
272   return unwrap(attr).isa<UnitAttr>();
273 }
274 
275 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
276   return wrap(UnitAttr::get(unwrap(ctx)));
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Elements attributes.
281 //===----------------------------------------------------------------------===//
282 
283 bool mlirAttributeIsAElements(MlirAttribute attr) {
284   return unwrap(attr).isa<ElementsAttr>();
285 }
286 
287 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
288                                        uint64_t *idxs) {
289   return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
290       llvm::makeArrayRef(idxs, rank)));
291 }
292 
293 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
294                                   uint64_t *idxs) {
295   return unwrap(attr).cast<ElementsAttr>().isValidIndex(
296       llvm::makeArrayRef(idxs, rank));
297 }
298 
299 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
300   return unwrap(attr).cast<ElementsAttr>().getNumElements();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Dense elements attribute.
305 //===----------------------------------------------------------------------===//
306 
307 //===----------------------------------------------------------------------===//
308 // IsA support.
309 
310 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
311   return unwrap(attr).isa<DenseElementsAttr>();
312 }
313 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
314   return unwrap(attr).isa<DenseIntElementsAttr>();
315 }
316 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
317   return unwrap(attr).isa<DenseFPElementsAttr>();
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // Constructors.
322 
323 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
324                                        intptr_t numElements,
325                                        MlirAttribute const *elements) {
326   SmallVector<Attribute, 8> attributes;
327   return wrap(
328       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
329                              unwrapList(numElements, elements, attributes)));
330 }
331 
332 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
333                                             MlirAttribute element) {
334   return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
335                                      unwrap(element)));
336 }
337 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
338                                                 bool element) {
339   return wrap(
340       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
341 }
342 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
343                                                   uint32_t element) {
344   return wrap(
345       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
346 }
347 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
348                                                  int32_t element) {
349   return wrap(
350       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
351 }
352 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
353                                                   uint64_t element) {
354   return wrap(
355       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
356 }
357 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
358                                                  int64_t element) {
359   return wrap(
360       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
361 }
362 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
363                                                  float element) {
364   return wrap(
365       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
366 }
367 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
368                                                   double element) {
369   return wrap(
370       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
371 }
372 
373 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
374                                            intptr_t numElements,
375                                            const int *elements) {
376   SmallVector<bool, 8> values(elements, elements + numElements);
377   return wrap(
378       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
379 }
380 
381 /// Creates a dense attribute with elements of the type deduced by templates.
382 template <typename T>
383 static MlirAttribute getDenseAttribute(MlirType shapedType,
384                                        intptr_t numElements,
385                                        const T *elements) {
386   return wrap(
387       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
388                              llvm::makeArrayRef(elements, numElements)));
389 }
390 
391 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
392                                              intptr_t numElements,
393                                              const uint32_t *elements) {
394   return getDenseAttribute(shapedType, numElements, elements);
395 }
396 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
397                                             intptr_t numElements,
398                                             const int32_t *elements) {
399   return getDenseAttribute(shapedType, numElements, elements);
400 }
401 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
402                                              intptr_t numElements,
403                                              const uint64_t *elements) {
404   return getDenseAttribute(shapedType, numElements, elements);
405 }
406 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
407                                             intptr_t numElements,
408                                             const int64_t *elements) {
409   return getDenseAttribute(shapedType, numElements, elements);
410 }
411 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
412                                             intptr_t numElements,
413                                             const float *elements) {
414   return getDenseAttribute(shapedType, numElements, elements);
415 }
416 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
417                                              intptr_t numElements,
418                                              const double *elements) {
419   return getDenseAttribute(shapedType, numElements, elements);
420 }
421 
422 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
423                                              intptr_t numElements,
424                                              MlirStringRef *strs) {
425   SmallVector<StringRef, 8> values;
426   values.reserve(numElements);
427   for (intptr_t i = 0; i < numElements; ++i)
428     values.push_back(unwrap(strs[i]));
429 
430   return wrap(
431       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
432 }
433 
434 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
435                                               MlirType shapedType) {
436   return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
437       unwrap(shapedType).cast<ShapedType>()));
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // Splat accessors.
442 
443 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
444   return unwrap(attr).cast<DenseElementsAttr>().isSplat();
445 }
446 
447 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
448   return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
449 }
450 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
451   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
452 }
453 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
454   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
455 }
456 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
457   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
458 }
459 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
460   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
461 }
462 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
463   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
464 }
465 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
466   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
467 }
468 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
469   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
470 }
471 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
472   return wrap(
473       unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // Indexed accessors.
478 
479 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
480   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
481            pos);
482 }
483 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
484   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
485            pos);
486 }
487 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
488   return *(
489       unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
490       pos);
491 }
492 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
493   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
494            pos);
495 }
496 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
497   return *(
498       unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
499       pos);
500 }
501 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
502   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
503            pos);
504 }
505 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
506   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
507            pos);
508 }
509 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
510                                                   intptr_t pos) {
511   return wrap(
512       *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
513         pos));
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // Raw data accessors.
518 
519 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
520   return static_cast<const void *>(
521       unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // Opaque elements attribute.
526 //===----------------------------------------------------------------------===//
527 
528 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
529   return unwrap(attr).isa<OpaqueElementsAttr>();
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // Sparse elements attribute.
534 //===----------------------------------------------------------------------===//
535 
536 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
537   return unwrap(attr).isa<SparseElementsAttr>();
538 }
539 
540 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
541                                           MlirAttribute denseIndices,
542                                           MlirAttribute denseValues) {
543   return wrap(
544       SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
545                               unwrap(denseIndices).cast<DenseElementsAttr>(),
546                               unwrap(denseValues).cast<DenseElementsAttr>()));
547 }
548 
549 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
550   return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
551 }
552 
553 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
554   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
555 }
556