1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir/CAPI/AffineMap.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15
16 using namespace mlir;
17
mlirAttributeGetNull()18 MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
19
20 //===----------------------------------------------------------------------===//
21 // Affine map attribute.
22 //===----------------------------------------------------------------------===//
23
mlirAttributeIsAAffineMap(MlirAttribute attr)24 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
25 return unwrap(attr).isa<AffineMapAttr>();
26 }
27
mlirAffineMapAttrGet(MlirAffineMap map)28 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
29 return wrap(AffineMapAttr::get(unwrap(map)));
30 }
31
mlirAffineMapAttrGetValue(MlirAttribute attr)32 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
33 return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
34 }
35
36 //===----------------------------------------------------------------------===//
37 // Array attribute.
38 //===----------------------------------------------------------------------===//
39
mlirAttributeIsAArray(MlirAttribute attr)40 bool mlirAttributeIsAArray(MlirAttribute attr) {
41 return unwrap(attr).isa<ArrayAttr>();
42 }
43
mlirArrayAttrGet(MlirContext ctx,intptr_t numElements,MlirAttribute const * elements)44 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
45 MlirAttribute const *elements) {
46 SmallVector<Attribute, 8> attrs;
47 return wrap(
48 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
49 elements, attrs)));
50 }
51
mlirArrayAttrGetNumElements(MlirAttribute attr)52 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
53 return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
54 }
55
mlirArrayAttrGetElement(MlirAttribute attr,intptr_t pos)56 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
57 return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
58 }
59
60 //===----------------------------------------------------------------------===//
61 // Dictionary attribute.
62 //===----------------------------------------------------------------------===//
63
mlirAttributeIsADictionary(MlirAttribute attr)64 bool mlirAttributeIsADictionary(MlirAttribute attr) {
65 return unwrap(attr).isa<DictionaryAttr>();
66 }
67
mlirDictionaryAttrGet(MlirContext ctx,intptr_t numElements,MlirNamedAttribute const * elements)68 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
69 MlirNamedAttribute const *elements) {
70 SmallVector<NamedAttribute, 8> attributes;
71 attributes.reserve(numElements);
72 for (intptr_t i = 0; i < numElements; ++i)
73 attributes.emplace_back(unwrap(elements[i].name),
74 unwrap(elements[i].attribute));
75 return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
76 }
77
mlirDictionaryAttrGetNumElements(MlirAttribute attr)78 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
79 return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
80 }
81
mlirDictionaryAttrGetElement(MlirAttribute attr,intptr_t pos)82 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
83 intptr_t pos) {
84 NamedAttribute attribute =
85 unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
86 return {wrap(attribute.getName()), wrap(attribute.getValue())};
87 }
88
mlirDictionaryAttrGetElementByName(MlirAttribute attr,MlirStringRef name)89 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
90 MlirStringRef name) {
91 return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
92 }
93
94 //===----------------------------------------------------------------------===//
95 // Floating point attribute.
96 //===----------------------------------------------------------------------===//
97
mlirAttributeIsAFloat(MlirAttribute attr)98 bool mlirAttributeIsAFloat(MlirAttribute attr) {
99 return unwrap(attr).isa<FloatAttr>();
100 }
101
mlirFloatAttrDoubleGet(MlirContext ctx,MlirType type,double value)102 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
103 double value) {
104 return wrap(FloatAttr::get(unwrap(type), value));
105 }
106
mlirFloatAttrDoubleGetChecked(MlirLocation loc,MlirType type,double value)107 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
108 double value) {
109 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
110 }
111
mlirFloatAttrGetValueDouble(MlirAttribute attr)112 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
113 return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
114 }
115
116 //===----------------------------------------------------------------------===//
117 // Integer attribute.
118 //===----------------------------------------------------------------------===//
119
mlirAttributeIsAInteger(MlirAttribute attr)120 bool mlirAttributeIsAInteger(MlirAttribute attr) {
121 return unwrap(attr).isa<IntegerAttr>();
122 }
123
mlirIntegerAttrGet(MlirType type,int64_t value)124 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
125 return wrap(IntegerAttr::get(unwrap(type), value));
126 }
127
mlirIntegerAttrGetValueInt(MlirAttribute attr)128 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
129 return unwrap(attr).cast<IntegerAttr>().getInt();
130 }
131
mlirIntegerAttrGetValueSInt(MlirAttribute attr)132 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
133 return unwrap(attr).cast<IntegerAttr>().getSInt();
134 }
135
mlirIntegerAttrGetValueUInt(MlirAttribute attr)136 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
137 return unwrap(attr).cast<IntegerAttr>().getUInt();
138 }
139
140 //===----------------------------------------------------------------------===//
141 // Bool attribute.
142 //===----------------------------------------------------------------------===//
143
mlirAttributeIsABool(MlirAttribute attr)144 bool mlirAttributeIsABool(MlirAttribute attr) {
145 return unwrap(attr).isa<BoolAttr>();
146 }
147
mlirBoolAttrGet(MlirContext ctx,int value)148 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
149 return wrap(BoolAttr::get(unwrap(ctx), value));
150 }
151
mlirBoolAttrGetValue(MlirAttribute attr)152 bool mlirBoolAttrGetValue(MlirAttribute attr) {
153 return unwrap(attr).cast<BoolAttr>().getValue();
154 }
155
156 //===----------------------------------------------------------------------===//
157 // Integer set attribute.
158 //===----------------------------------------------------------------------===//
159
mlirAttributeIsAIntegerSet(MlirAttribute attr)160 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
161 return unwrap(attr).isa<IntegerSetAttr>();
162 }
163
164 //===----------------------------------------------------------------------===//
165 // Opaque attribute.
166 //===----------------------------------------------------------------------===//
167
mlirAttributeIsAOpaque(MlirAttribute attr)168 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
169 return unwrap(attr).isa<OpaqueAttr>();
170 }
171
mlirOpaqueAttrGet(MlirContext ctx,MlirStringRef dialectNamespace,intptr_t dataLength,const char * data,MlirType type)172 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
173 intptr_t dataLength, const char *data,
174 MlirType type) {
175 return wrap(
176 OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
177 StringRef(data, dataLength), unwrap(type)));
178 }
179
mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)180 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
181 return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
182 }
183
mlirOpaqueAttrGetData(MlirAttribute attr)184 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
185 return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
186 }
187
188 //===----------------------------------------------------------------------===//
189 // String attribute.
190 //===----------------------------------------------------------------------===//
191
mlirAttributeIsAString(MlirAttribute attr)192 bool mlirAttributeIsAString(MlirAttribute attr) {
193 return unwrap(attr).isa<StringAttr>();
194 }
195
mlirStringAttrGet(MlirContext ctx,MlirStringRef str)196 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
197 return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
198 }
199
mlirStringAttrTypedGet(MlirType type,MlirStringRef str)200 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
201 return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
202 }
203
mlirStringAttrGetValue(MlirAttribute attr)204 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
205 return wrap(unwrap(attr).cast<StringAttr>().getValue());
206 }
207
208 //===----------------------------------------------------------------------===//
209 // SymbolRef attribute.
210 //===----------------------------------------------------------------------===//
211
mlirAttributeIsASymbolRef(MlirAttribute attr)212 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
213 return unwrap(attr).isa<SymbolRefAttr>();
214 }
215
mlirSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol,intptr_t numReferences,MlirAttribute const * references)216 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
217 intptr_t numReferences,
218 MlirAttribute const *references) {
219 SmallVector<FlatSymbolRefAttr, 4> refs;
220 refs.reserve(numReferences);
221 for (intptr_t i = 0; i < numReferences; ++i)
222 refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
223 auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
224 return wrap(SymbolRefAttr::get(symbolAttr, refs));
225 }
226
mlirSymbolRefAttrGetRootReference(MlirAttribute attr)227 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
228 return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference().getValue());
229 }
230
mlirSymbolRefAttrGetLeafReference(MlirAttribute attr)231 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
232 return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference().getValue());
233 }
234
mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)235 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
236 return static_cast<intptr_t>(
237 unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
238 }
239
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,intptr_t pos)240 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
241 intptr_t pos) {
242 return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
243 }
244
245 //===----------------------------------------------------------------------===//
246 // Flat SymbolRef attribute.
247 //===----------------------------------------------------------------------===//
248
mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)249 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
250 return unwrap(attr).isa<FlatSymbolRefAttr>();
251 }
252
mlirFlatSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol)253 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
254 return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
255 }
256
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)257 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
258 return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
259 }
260
261 //===----------------------------------------------------------------------===//
262 // Type attribute.
263 //===----------------------------------------------------------------------===//
264
mlirAttributeIsAType(MlirAttribute attr)265 bool mlirAttributeIsAType(MlirAttribute attr) {
266 return unwrap(attr).isa<TypeAttr>();
267 }
268
mlirTypeAttrGet(MlirType type)269 MlirAttribute mlirTypeAttrGet(MlirType type) {
270 return wrap(TypeAttr::get(unwrap(type)));
271 }
272
mlirTypeAttrGetValue(MlirAttribute attr)273 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
274 return wrap(unwrap(attr).cast<TypeAttr>().getValue());
275 }
276
277 //===----------------------------------------------------------------------===//
278 // Unit attribute.
279 //===----------------------------------------------------------------------===//
280
mlirAttributeIsAUnit(MlirAttribute attr)281 bool mlirAttributeIsAUnit(MlirAttribute attr) {
282 return unwrap(attr).isa<UnitAttr>();
283 }
284
mlirUnitAttrGet(MlirContext ctx)285 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
286 return wrap(UnitAttr::get(unwrap(ctx)));
287 }
288
289 //===----------------------------------------------------------------------===//
290 // Elements attributes.
291 //===----------------------------------------------------------------------===//
292
mlirAttributeIsAElements(MlirAttribute attr)293 bool mlirAttributeIsAElements(MlirAttribute attr) {
294 return unwrap(attr).isa<ElementsAttr>();
295 }
296
mlirElementsAttrGetValue(MlirAttribute attr,intptr_t rank,uint64_t * idxs)297 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
298 uint64_t *idxs) {
299 return wrap(unwrap(attr)
300 .cast<ElementsAttr>()
301 .getValues<Attribute>()[llvm::makeArrayRef(idxs, rank)]);
302 }
303
mlirElementsAttrIsValidIndex(MlirAttribute attr,intptr_t rank,uint64_t * idxs)304 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
305 uint64_t *idxs) {
306 return unwrap(attr).cast<ElementsAttr>().isValidIndex(
307 llvm::makeArrayRef(idxs, rank));
308 }
309
mlirElementsAttrGetNumElements(MlirAttribute attr)310 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
311 return unwrap(attr).cast<ElementsAttr>().getNumElements();
312 }
313
314 //===----------------------------------------------------------------------===//
315 // Dense elements attribute.
316 //===----------------------------------------------------------------------===//
317
318 //===----------------------------------------------------------------------===//
319 // IsA support.
320
mlirAttributeIsADenseElements(MlirAttribute attr)321 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
322 return unwrap(attr).isa<DenseElementsAttr>();
323 }
mlirAttributeIsADenseIntElements(MlirAttribute attr)324 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
325 return unwrap(attr).isa<DenseIntElementsAttr>();
326 }
mlirAttributeIsADenseFPElements(MlirAttribute attr)327 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
328 return unwrap(attr).isa<DenseFPElementsAttr>();
329 }
330
331 //===----------------------------------------------------------------------===//
332 // Constructors.
333
mlirDenseElementsAttrGet(MlirType shapedType,intptr_t numElements,MlirAttribute const * elements)334 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
335 intptr_t numElements,
336 MlirAttribute const *elements) {
337 SmallVector<Attribute, 8> attributes;
338 return wrap(
339 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
340 unwrapList(numElements, elements, attributes)));
341 }
342
mlirDenseElementsAttrRawBufferGet(MlirType shapedType,size_t rawBufferSize,const void * rawBuffer)343 MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
344 size_t rawBufferSize,
345 const void *rawBuffer) {
346 auto shapedTypeCpp = unwrap(shapedType).cast<ShapedType>();
347 ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
348 rawBufferSize);
349 bool isSplat = false;
350 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
351 isSplat))
352 return mlirAttributeGetNull();
353 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
354 }
355
mlirDenseElementsAttrSplatGet(MlirType shapedType,MlirAttribute element)356 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
357 MlirAttribute element) {
358 return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
359 unwrap(element)));
360 }
mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,bool element)361 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
362 bool element) {
363 return wrap(
364 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
365 }
mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,uint8_t element)366 MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
367 uint8_t element) {
368 return wrap(
369 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
370 }
mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,int8_t element)371 MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
372 int8_t element) {
373 return wrap(
374 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
375 }
mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,uint32_t element)376 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
377 uint32_t element) {
378 return wrap(
379 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
380 }
mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,int32_t element)381 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
382 int32_t element) {
383 return wrap(
384 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
385 }
mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,uint64_t element)386 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
387 uint64_t element) {
388 return wrap(
389 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
390 }
mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,int64_t element)391 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
392 int64_t element) {
393 return wrap(
394 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
395 }
mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,float element)396 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
397 float element) {
398 return wrap(
399 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
400 }
mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,double element)401 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
402 double element) {
403 return wrap(
404 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
405 }
406
mlirDenseElementsAttrBoolGet(MlirType shapedType,intptr_t numElements,const int * elements)407 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
408 intptr_t numElements,
409 const int *elements) {
410 SmallVector<bool, 8> values(elements, elements + numElements);
411 return wrap(
412 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
413 }
414
415 /// Creates a dense attribute with elements of the type deduced by templates.
416 template <typename T>
getDenseAttribute(MlirType shapedType,intptr_t numElements,const T * elements)417 static MlirAttribute getDenseAttribute(MlirType shapedType,
418 intptr_t numElements,
419 const T *elements) {
420 return wrap(
421 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
422 llvm::makeArrayRef(elements, numElements)));
423 }
424
mlirDenseElementsAttrUInt8Get(MlirType shapedType,intptr_t numElements,const uint8_t * elements)425 MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
426 intptr_t numElements,
427 const uint8_t *elements) {
428 return getDenseAttribute(shapedType, numElements, elements);
429 }
mlirDenseElementsAttrInt8Get(MlirType shapedType,intptr_t numElements,const int8_t * elements)430 MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
431 intptr_t numElements,
432 const int8_t *elements) {
433 return getDenseAttribute(shapedType, numElements, elements);
434 }
mlirDenseElementsAttrUInt16Get(MlirType shapedType,intptr_t numElements,const uint16_t * elements)435 MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
436 intptr_t numElements,
437 const uint16_t *elements) {
438 return getDenseAttribute(shapedType, numElements, elements);
439 }
mlirDenseElementsAttrInt16Get(MlirType shapedType,intptr_t numElements,const int16_t * elements)440 MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
441 intptr_t numElements,
442 const int16_t *elements) {
443 return getDenseAttribute(shapedType, numElements, elements);
444 }
mlirDenseElementsAttrUInt32Get(MlirType shapedType,intptr_t numElements,const uint32_t * elements)445 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
446 intptr_t numElements,
447 const uint32_t *elements) {
448 return getDenseAttribute(shapedType, numElements, elements);
449 }
mlirDenseElementsAttrInt32Get(MlirType shapedType,intptr_t numElements,const int32_t * elements)450 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
451 intptr_t numElements,
452 const int32_t *elements) {
453 return getDenseAttribute(shapedType, numElements, elements);
454 }
mlirDenseElementsAttrUInt64Get(MlirType shapedType,intptr_t numElements,const uint64_t * elements)455 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
456 intptr_t numElements,
457 const uint64_t *elements) {
458 return getDenseAttribute(shapedType, numElements, elements);
459 }
mlirDenseElementsAttrInt64Get(MlirType shapedType,intptr_t numElements,const int64_t * elements)460 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
461 intptr_t numElements,
462 const int64_t *elements) {
463 return getDenseAttribute(shapedType, numElements, elements);
464 }
mlirDenseElementsAttrFloatGet(MlirType shapedType,intptr_t numElements,const float * elements)465 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
466 intptr_t numElements,
467 const float *elements) {
468 return getDenseAttribute(shapedType, numElements, elements);
469 }
mlirDenseElementsAttrDoubleGet(MlirType shapedType,intptr_t numElements,const double * elements)470 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
471 intptr_t numElements,
472 const double *elements) {
473 return getDenseAttribute(shapedType, numElements, elements);
474 }
mlirDenseElementsAttrBFloat16Get(MlirType shapedType,intptr_t numElements,const uint16_t * elements)475 MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
476 intptr_t numElements,
477 const uint16_t *elements) {
478 size_t bufferSize = numElements * 2;
479 const void *buffer = static_cast<const void *>(elements);
480 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
481 }
mlirDenseElementsAttrFloat16Get(MlirType shapedType,intptr_t numElements,const uint16_t * elements)482 MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
483 intptr_t numElements,
484 const uint16_t *elements) {
485 size_t bufferSize = numElements * 2;
486 const void *buffer = static_cast<const void *>(elements);
487 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
488 }
489
mlirDenseElementsAttrStringGet(MlirType shapedType,intptr_t numElements,MlirStringRef * strs)490 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
491 intptr_t numElements,
492 MlirStringRef *strs) {
493 SmallVector<StringRef, 8> values;
494 values.reserve(numElements);
495 for (intptr_t i = 0; i < numElements; ++i)
496 values.push_back(unwrap(strs[i]));
497
498 return wrap(
499 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
500 }
501
mlirDenseElementsAttrReshapeGet(MlirAttribute attr,MlirType shapedType)502 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
503 MlirType shapedType) {
504 return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
505 unwrap(shapedType).cast<ShapedType>()));
506 }
507
508 //===----------------------------------------------------------------------===//
509 // Splat accessors.
510
mlirDenseElementsAttrIsSplat(MlirAttribute attr)511 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
512 return unwrap(attr).cast<DenseElementsAttr>().isSplat();
513 }
514
mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)515 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
516 return wrap(
517 unwrap(attr).cast<DenseElementsAttr>().getSplatValue<Attribute>());
518 }
mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr)519 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
520 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
521 }
mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr)522 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
523 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
524 }
mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr)525 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
526 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
527 }
mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr)528 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
529 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
530 }
mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr)531 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
532 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
533 }
mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr)534 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
535 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
536 }
mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr)537 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
538 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
539 }
mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr)540 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
541 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
542 }
mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr)543 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
544 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
545 }
mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr)546 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
547 return wrap(
548 unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
549 }
550
551 //===----------------------------------------------------------------------===//
552 // Indexed accessors.
553
mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,intptr_t pos)554 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
555 return unwrap(attr).cast<DenseElementsAttr>().getValues<bool>()[pos];
556 }
mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,intptr_t pos)557 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
558 return unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>()[pos];
559 }
mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr,intptr_t pos)560 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
561 return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
562 }
mlirDenseElementsAttrGetInt16Value(MlirAttribute attr,intptr_t pos)563 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
564 return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
565 }
mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr,intptr_t pos)566 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
567 return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
568 }
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr,intptr_t pos)569 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
570 return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
571 }
mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr,intptr_t pos)572 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
573 return unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>()[pos];
574 }
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr,intptr_t pos)575 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
576 return unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>()[pos];
577 }
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr,intptr_t pos)578 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
579 return unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>()[pos];
580 }
mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,intptr_t pos)581 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
582 return unwrap(attr).cast<DenseElementsAttr>().getValues<float>()[pos];
583 }
mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr,intptr_t pos)584 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
585 return unwrap(attr).cast<DenseElementsAttr>().getValues<double>()[pos];
586 }
mlirDenseElementsAttrGetStringValue(MlirAttribute attr,intptr_t pos)587 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
588 intptr_t pos) {
589 return wrap(
590 unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>()[pos]);
591 }
592
593 //===----------------------------------------------------------------------===//
594 // Raw data accessors.
595
mlirDenseElementsAttrGetRawData(MlirAttribute attr)596 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
597 return static_cast<const void *>(
598 unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
599 }
600
601 //===----------------------------------------------------------------------===//
602 // Opaque elements attribute.
603 //===----------------------------------------------------------------------===//
604
mlirAttributeIsAOpaqueElements(MlirAttribute attr)605 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
606 return unwrap(attr).isa<OpaqueElementsAttr>();
607 }
608
609 //===----------------------------------------------------------------------===//
610 // Sparse elements attribute.
611 //===----------------------------------------------------------------------===//
612
mlirAttributeIsASparseElements(MlirAttribute attr)613 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
614 return unwrap(attr).isa<SparseElementsAttr>();
615 }
616
mlirSparseElementsAttribute(MlirType shapedType,MlirAttribute denseIndices,MlirAttribute denseValues)617 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
618 MlirAttribute denseIndices,
619 MlirAttribute denseValues) {
620 return wrap(
621 SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
622 unwrap(denseIndices).cast<DenseElementsAttr>(),
623 unwrap(denseValues).cast<DenseElementsAttr>()));
624 }
625
mlirSparseElementsAttrGetIndices(MlirAttribute attr)626 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
627 return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
628 }
629
mlirSparseElementsAttrGetValues(MlirAttribute attr)630 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
631 return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
632 }
633