1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir/CAPI/AffineMap.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 
16 using namespace mlir;
17 
18 MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
19 
20 //===----------------------------------------------------------------------===//
21 // Affine map attribute.
22 //===----------------------------------------------------------------------===//
23 
24 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
25   return unwrap(attr).isa<AffineMapAttr>();
26 }
27 
28 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
29   return wrap(AffineMapAttr::get(unwrap(map)));
30 }
31 
32 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
33   return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Array attribute.
38 //===----------------------------------------------------------------------===//
39 
40 bool mlirAttributeIsAArray(MlirAttribute attr) {
41   return unwrap(attr).isa<ArrayAttr>();
42 }
43 
44 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
45                                MlirAttribute const *elements) {
46   SmallVector<Attribute, 8> attrs;
47   return wrap(
48       ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
49                                              elements, attrs)));
50 }
51 
52 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
53   return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
54 }
55 
56 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
57   return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Dictionary attribute.
62 //===----------------------------------------------------------------------===//
63 
64 bool mlirAttributeIsADictionary(MlirAttribute attr) {
65   return unwrap(attr).isa<DictionaryAttr>();
66 }
67 
68 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
69                                     MlirNamedAttribute const *elements) {
70   SmallVector<NamedAttribute, 8> attributes;
71   attributes.reserve(numElements);
72   for (intptr_t i = 0; i < numElements; ++i)
73     attributes.emplace_back(
74         Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
75         unwrap(elements[i].attribute));
76   return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
77 }
78 
79 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
80   return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
81 }
82 
83 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
84                                                 intptr_t pos) {
85   NamedAttribute attribute =
86       unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
87   return {wrap(attribute.first), wrap(attribute.second)};
88 }
89 
90 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
91                                                  MlirStringRef name) {
92   return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Floating point attribute.
97 //===----------------------------------------------------------------------===//
98 
99 bool mlirAttributeIsAFloat(MlirAttribute attr) {
100   return unwrap(attr).isa<FloatAttr>();
101 }
102 
103 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
104                                      double value) {
105   return wrap(FloatAttr::get(unwrap(type), value));
106 }
107 
108 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
109                                             double value) {
110   return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
111 }
112 
113 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
114   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // Integer attribute.
119 //===----------------------------------------------------------------------===//
120 
121 bool mlirAttributeIsAInteger(MlirAttribute attr) {
122   return unwrap(attr).isa<IntegerAttr>();
123 }
124 
125 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
126   return wrap(IntegerAttr::get(unwrap(type), value));
127 }
128 
129 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
130   return unwrap(attr).cast<IntegerAttr>().getInt();
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Bool attribute.
135 //===----------------------------------------------------------------------===//
136 
137 bool mlirAttributeIsABool(MlirAttribute attr) {
138   return unwrap(attr).isa<BoolAttr>();
139 }
140 
141 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
142   return wrap(BoolAttr::get(unwrap(ctx), value));
143 }
144 
145 bool mlirBoolAttrGetValue(MlirAttribute attr) {
146   return unwrap(attr).cast<BoolAttr>().getValue();
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // Integer set attribute.
151 //===----------------------------------------------------------------------===//
152 
153 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
154   return unwrap(attr).isa<IntegerSetAttr>();
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // Opaque attribute.
159 //===----------------------------------------------------------------------===//
160 
161 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
162   return unwrap(attr).isa<OpaqueAttr>();
163 }
164 
165 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
166                                 intptr_t dataLength, const char *data,
167                                 MlirType type) {
168   return wrap(
169       OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
170                       StringRef(data, dataLength), unwrap(type)));
171 }
172 
173 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
174   return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
175 }
176 
177 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
178   return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // String attribute.
183 //===----------------------------------------------------------------------===//
184 
185 bool mlirAttributeIsAString(MlirAttribute attr) {
186   return unwrap(attr).isa<StringAttr>();
187 }
188 
189 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
190   return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
191 }
192 
193 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
194   return wrap(StringAttr::get(unwrap(str), unwrap(type)));
195 }
196 
197 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
198   return wrap(unwrap(attr).cast<StringAttr>().getValue());
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // SymbolRef attribute.
203 //===----------------------------------------------------------------------===//
204 
205 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
206   return unwrap(attr).isa<SymbolRefAttr>();
207 }
208 
209 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
210                                    intptr_t numReferences,
211                                    MlirAttribute const *references) {
212   SmallVector<FlatSymbolRefAttr, 4> refs;
213   refs.reserve(numReferences);
214   for (intptr_t i = 0; i < numReferences; ++i)
215     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
216   return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
217 }
218 
219 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
220   return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
221 }
222 
223 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
224   return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
225 }
226 
227 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
228   return static_cast<intptr_t>(
229       unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
230 }
231 
232 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
233                                                   intptr_t pos) {
234   return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // Flat SymbolRef attribute.
239 //===----------------------------------------------------------------------===//
240 
241 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
242   return unwrap(attr).isa<FlatSymbolRefAttr>();
243 }
244 
245 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
246   return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
247 }
248 
249 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
250   return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Type attribute.
255 //===----------------------------------------------------------------------===//
256 
257 bool mlirAttributeIsAType(MlirAttribute attr) {
258   return unwrap(attr).isa<TypeAttr>();
259 }
260 
261 MlirAttribute mlirTypeAttrGet(MlirType type) {
262   return wrap(TypeAttr::get(unwrap(type)));
263 }
264 
265 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
266   return wrap(unwrap(attr).cast<TypeAttr>().getValue());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Unit attribute.
271 //===----------------------------------------------------------------------===//
272 
273 bool mlirAttributeIsAUnit(MlirAttribute attr) {
274   return unwrap(attr).isa<UnitAttr>();
275 }
276 
277 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
278   return wrap(UnitAttr::get(unwrap(ctx)));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // Elements attributes.
283 //===----------------------------------------------------------------------===//
284 
285 bool mlirAttributeIsAElements(MlirAttribute attr) {
286   return unwrap(attr).isa<ElementsAttr>();
287 }
288 
289 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
290                                        uint64_t *idxs) {
291   return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
292       llvm::makeArrayRef(idxs, rank)));
293 }
294 
295 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
296                                   uint64_t *idxs) {
297   return unwrap(attr).cast<ElementsAttr>().isValidIndex(
298       llvm::makeArrayRef(idxs, rank));
299 }
300 
301 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
302   return unwrap(attr).cast<ElementsAttr>().getNumElements();
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Dense elements attribute.
307 //===----------------------------------------------------------------------===//
308 
309 //===----------------------------------------------------------------------===//
310 // IsA support.
311 
312 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
313   return unwrap(attr).isa<DenseElementsAttr>();
314 }
315 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
316   return unwrap(attr).isa<DenseIntElementsAttr>();
317 }
318 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
319   return unwrap(attr).isa<DenseFPElementsAttr>();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // Constructors.
324 
325 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
326                                        intptr_t numElements,
327                                        MlirAttribute const *elements) {
328   SmallVector<Attribute, 8> attributes;
329   return wrap(
330       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
331                              unwrapList(numElements, elements, attributes)));
332 }
333 
334 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
335                                             MlirAttribute element) {
336   return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
337                                      unwrap(element)));
338 }
339 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
340                                                 bool element) {
341   return wrap(
342       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
343 }
344 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
345                                                   uint32_t element) {
346   return wrap(
347       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
348 }
349 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
350                                                  int32_t element) {
351   return wrap(
352       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
353 }
354 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
355                                                   uint64_t element) {
356   return wrap(
357       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
358 }
359 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
360                                                  int64_t element) {
361   return wrap(
362       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
363 }
364 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
365                                                  float element) {
366   return wrap(
367       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
368 }
369 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
370                                                   double element) {
371   return wrap(
372       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
373 }
374 
375 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
376                                            intptr_t numElements,
377                                            const int *elements) {
378   SmallVector<bool, 8> values(elements, elements + numElements);
379   return wrap(
380       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
381 }
382 
383 /// Creates a dense attribute with elements of the type deduced by templates.
384 template <typename T>
385 static MlirAttribute getDenseAttribute(MlirType shapedType,
386                                        intptr_t numElements,
387                                        const T *elements) {
388   return wrap(
389       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
390                              llvm::makeArrayRef(elements, numElements)));
391 }
392 
393 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
394                                              intptr_t numElements,
395                                              const uint32_t *elements) {
396   return getDenseAttribute(shapedType, numElements, elements);
397 }
398 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
399                                             intptr_t numElements,
400                                             const int32_t *elements) {
401   return getDenseAttribute(shapedType, numElements, elements);
402 }
403 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
404                                              intptr_t numElements,
405                                              const uint64_t *elements) {
406   return getDenseAttribute(shapedType, numElements, elements);
407 }
408 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
409                                             intptr_t numElements,
410                                             const int64_t *elements) {
411   return getDenseAttribute(shapedType, numElements, elements);
412 }
413 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
414                                             intptr_t numElements,
415                                             const float *elements) {
416   return getDenseAttribute(shapedType, numElements, elements);
417 }
418 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
419                                              intptr_t numElements,
420                                              const double *elements) {
421   return getDenseAttribute(shapedType, numElements, elements);
422 }
423 
424 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
425                                              intptr_t numElements,
426                                              MlirStringRef *strs) {
427   SmallVector<StringRef, 8> values;
428   values.reserve(numElements);
429   for (intptr_t i = 0; i < numElements; ++i)
430     values.push_back(unwrap(strs[i]));
431 
432   return wrap(
433       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
434 }
435 
436 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
437                                               MlirType shapedType) {
438   return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
439       unwrap(shapedType).cast<ShapedType>()));
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // Splat accessors.
444 
445 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
446   return unwrap(attr).cast<DenseElementsAttr>().isSplat();
447 }
448 
449 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
450   return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
451 }
452 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
453   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
454 }
455 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
456   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
457 }
458 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
459   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
460 }
461 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
462   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
463 }
464 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
465   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
466 }
467 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
468   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
469 }
470 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
471   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
472 }
473 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
474   return wrap(
475       unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Indexed accessors.
480 
481 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
482   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
483            pos);
484 }
485 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
486   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
487            pos);
488 }
489 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
490   return *(
491       unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
492       pos);
493 }
494 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
495   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
496            pos);
497 }
498 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
499   return *(
500       unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
501       pos);
502 }
503 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
504   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
505            pos);
506 }
507 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
508   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
509            pos);
510 }
511 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
512                                                   intptr_t pos) {
513   return wrap(
514       *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
515         pos));
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Raw data accessors.
520 
521 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
522   return static_cast<const void *>(
523       unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // Opaque elements attribute.
528 //===----------------------------------------------------------------------===//
529 
530 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
531   return unwrap(attr).isa<OpaqueElementsAttr>();
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // Sparse elements attribute.
536 //===----------------------------------------------------------------------===//
537 
538 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
539   return unwrap(attr).isa<SparseElementsAttr>();
540 }
541 
542 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
543                                           MlirAttribute denseIndices,
544                                           MlirAttribute denseValues) {
545   return wrap(
546       SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
547                               unwrap(denseIndices).cast<DenseElementsAttr>(),
548                               unwrap(denseValues).cast<DenseElementsAttr>()));
549 }
550 
551 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
552   return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
553 }
554 
555 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
556   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
557 }
558