1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/BuiltinAttributes.h" 10 #include "mlir/CAPI/AffineMap.h" 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Support.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 16 using namespace mlir; 17 18 MlirAttribute mlirAttributeGetNull() { return {nullptr}; } 19 20 //===----------------------------------------------------------------------===// 21 // Affine map attribute. 22 //===----------------------------------------------------------------------===// 23 24 bool mlirAttributeIsAAffineMap(MlirAttribute attr) { 25 return unwrap(attr).isa<AffineMapAttr>(); 26 } 27 28 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { 29 return wrap(AffineMapAttr::get(unwrap(map))); 30 } 31 32 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { 33 return wrap(unwrap(attr).cast<AffineMapAttr>().getValue()); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // Array attribute. 38 //===----------------------------------------------------------------------===// 39 40 bool mlirAttributeIsAArray(MlirAttribute attr) { 41 return unwrap(attr).isa<ArrayAttr>(); 42 } 43 44 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, 45 MlirAttribute const *elements) { 46 SmallVector<Attribute, 8> attrs; 47 return wrap( 48 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements), 49 elements, attrs))); 50 } 51 52 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { 53 return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size()); 54 } 55 56 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { 57 return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]); 58 } 59 60 //===----------------------------------------------------------------------===// 61 // Dictionary attribute. 62 //===----------------------------------------------------------------------===// 63 64 bool mlirAttributeIsADictionary(MlirAttribute attr) { 65 return unwrap(attr).isa<DictionaryAttr>(); 66 } 67 68 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, 69 MlirNamedAttribute const *elements) { 70 SmallVector<NamedAttribute, 8> attributes; 71 attributes.reserve(numElements); 72 for (intptr_t i = 0; i < numElements; ++i) 73 attributes.emplace_back( 74 Identifier::get(unwrap(elements[i].name), unwrap(ctx)), 75 unwrap(elements[i].attribute)); 76 return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); 77 } 78 79 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { 80 return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size()); 81 } 82 83 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, 84 intptr_t pos) { 85 NamedAttribute attribute = 86 unwrap(attr).cast<DictionaryAttr>().getValue()[pos]; 87 return {wrap(attribute.first), wrap(attribute.second)}; 88 } 89 90 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, 91 MlirStringRef name) { 92 return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name))); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // Floating point attribute. 97 //===----------------------------------------------------------------------===// 98 99 bool mlirAttributeIsAFloat(MlirAttribute attr) { 100 return unwrap(attr).isa<FloatAttr>(); 101 } 102 103 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, 104 double value) { 105 return wrap(FloatAttr::get(unwrap(type), value)); 106 } 107 108 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, 109 double value) { 110 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); 111 } 112 113 double mlirFloatAttrGetValueDouble(MlirAttribute attr) { 114 return unwrap(attr).cast<FloatAttr>().getValueAsDouble(); 115 } 116 117 //===----------------------------------------------------------------------===// 118 // Integer attribute. 119 //===----------------------------------------------------------------------===// 120 121 bool mlirAttributeIsAInteger(MlirAttribute attr) { 122 return unwrap(attr).isa<IntegerAttr>(); 123 } 124 125 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { 126 return wrap(IntegerAttr::get(unwrap(type), value)); 127 } 128 129 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { 130 return unwrap(attr).cast<IntegerAttr>().getInt(); 131 } 132 133 //===----------------------------------------------------------------------===// 134 // Bool attribute. 135 //===----------------------------------------------------------------------===// 136 137 bool mlirAttributeIsABool(MlirAttribute attr) { 138 return unwrap(attr).isa<BoolAttr>(); 139 } 140 141 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { 142 return wrap(BoolAttr::get(unwrap(ctx), value)); 143 } 144 145 bool mlirBoolAttrGetValue(MlirAttribute attr) { 146 return unwrap(attr).cast<BoolAttr>().getValue(); 147 } 148 149 //===----------------------------------------------------------------------===// 150 // Integer set attribute. 151 //===----------------------------------------------------------------------===// 152 153 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { 154 return unwrap(attr).isa<IntegerSetAttr>(); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // Opaque attribute. 159 //===----------------------------------------------------------------------===// 160 161 bool mlirAttributeIsAOpaque(MlirAttribute attr) { 162 return unwrap(attr).isa<OpaqueAttr>(); 163 } 164 165 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, 166 intptr_t dataLength, const char *data, 167 MlirType type) { 168 return wrap( 169 OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), 170 StringRef(data, dataLength), unwrap(type))); 171 } 172 173 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { 174 return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref()); 175 } 176 177 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { 178 return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData()); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // String attribute. 183 //===----------------------------------------------------------------------===// 184 185 bool mlirAttributeIsAString(MlirAttribute attr) { 186 return unwrap(attr).isa<StringAttr>(); 187 } 188 189 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { 190 return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); 191 } 192 193 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { 194 return wrap(StringAttr::get(unwrap(str), unwrap(type))); 195 } 196 197 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { 198 return wrap(unwrap(attr).cast<StringAttr>().getValue()); 199 } 200 201 //===----------------------------------------------------------------------===// 202 // SymbolRef attribute. 203 //===----------------------------------------------------------------------===// 204 205 bool mlirAttributeIsASymbolRef(MlirAttribute attr) { 206 return unwrap(attr).isa<SymbolRefAttr>(); 207 } 208 209 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, 210 intptr_t numReferences, 211 MlirAttribute const *references) { 212 SmallVector<FlatSymbolRefAttr, 4> refs; 213 refs.reserve(numReferences); 214 for (intptr_t i = 0; i < numReferences; ++i) 215 refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>()); 216 return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); 217 } 218 219 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { 220 return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference()); 221 } 222 223 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { 224 return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference()); 225 } 226 227 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { 228 return static_cast<intptr_t>( 229 unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size()); 230 } 231 232 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, 233 intptr_t pos) { 234 return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // Flat SymbolRef attribute. 239 //===----------------------------------------------------------------------===// 240 241 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { 242 return unwrap(attr).isa<FlatSymbolRefAttr>(); 243 } 244 245 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { 246 return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); 247 } 248 249 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { 250 return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue()); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // Type attribute. 255 //===----------------------------------------------------------------------===// 256 257 bool mlirAttributeIsAType(MlirAttribute attr) { 258 return unwrap(attr).isa<TypeAttr>(); 259 } 260 261 MlirAttribute mlirTypeAttrGet(MlirType type) { 262 return wrap(TypeAttr::get(unwrap(type))); 263 } 264 265 MlirType mlirTypeAttrGetValue(MlirAttribute attr) { 266 return wrap(unwrap(attr).cast<TypeAttr>().getValue()); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Unit attribute. 271 //===----------------------------------------------------------------------===// 272 273 bool mlirAttributeIsAUnit(MlirAttribute attr) { 274 return unwrap(attr).isa<UnitAttr>(); 275 } 276 277 MlirAttribute mlirUnitAttrGet(MlirContext ctx) { 278 return wrap(UnitAttr::get(unwrap(ctx))); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // Elements attributes. 283 //===----------------------------------------------------------------------===// 284 285 bool mlirAttributeIsAElements(MlirAttribute attr) { 286 return unwrap(attr).isa<ElementsAttr>(); 287 } 288 289 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, 290 uint64_t *idxs) { 291 return wrap(unwrap(attr).cast<ElementsAttr>().getValue( 292 llvm::makeArrayRef(idxs, rank))); 293 } 294 295 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, 296 uint64_t *idxs) { 297 return unwrap(attr).cast<ElementsAttr>().isValidIndex( 298 llvm::makeArrayRef(idxs, rank)); 299 } 300 301 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { 302 return unwrap(attr).cast<ElementsAttr>().getNumElements(); 303 } 304 305 //===----------------------------------------------------------------------===// 306 // Dense elements attribute. 307 //===----------------------------------------------------------------------===// 308 309 //===----------------------------------------------------------------------===// 310 // IsA support. 311 312 bool mlirAttributeIsADenseElements(MlirAttribute attr) { 313 return unwrap(attr).isa<DenseElementsAttr>(); 314 } 315 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { 316 return unwrap(attr).isa<DenseIntElementsAttr>(); 317 } 318 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { 319 return unwrap(attr).isa<DenseFPElementsAttr>(); 320 } 321 322 //===----------------------------------------------------------------------===// 323 // Constructors. 324 325 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, 326 intptr_t numElements, 327 MlirAttribute const *elements) { 328 SmallVector<Attribute, 8> attributes; 329 return wrap( 330 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 331 unwrapList(numElements, elements, attributes))); 332 } 333 334 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, 335 MlirAttribute element) { 336 return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 337 unwrap(element))); 338 } 339 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, 340 bool element) { 341 return wrap( 342 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 343 } 344 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, 345 uint32_t element) { 346 return wrap( 347 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 348 } 349 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, 350 int32_t element) { 351 return wrap( 352 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 353 } 354 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, 355 uint64_t element) { 356 return wrap( 357 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 358 } 359 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, 360 int64_t element) { 361 return wrap( 362 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 363 } 364 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, 365 float element) { 366 return wrap( 367 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 368 } 369 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, 370 double element) { 371 return wrap( 372 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 373 } 374 375 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, 376 intptr_t numElements, 377 const int *elements) { 378 SmallVector<bool, 8> values(elements, elements + numElements); 379 return wrap( 380 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values)); 381 } 382 383 /// Creates a dense attribute with elements of the type deduced by templates. 384 template <typename T> 385 static MlirAttribute getDenseAttribute(MlirType shapedType, 386 intptr_t numElements, 387 const T *elements) { 388 return wrap( 389 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 390 llvm::makeArrayRef(elements, numElements))); 391 } 392 393 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, 394 intptr_t numElements, 395 const uint32_t *elements) { 396 return getDenseAttribute(shapedType, numElements, elements); 397 } 398 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, 399 intptr_t numElements, 400 const int32_t *elements) { 401 return getDenseAttribute(shapedType, numElements, elements); 402 } 403 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, 404 intptr_t numElements, 405 const uint64_t *elements) { 406 return getDenseAttribute(shapedType, numElements, elements); 407 } 408 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, 409 intptr_t numElements, 410 const int64_t *elements) { 411 return getDenseAttribute(shapedType, numElements, elements); 412 } 413 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, 414 intptr_t numElements, 415 const float *elements) { 416 return getDenseAttribute(shapedType, numElements, elements); 417 } 418 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, 419 intptr_t numElements, 420 const double *elements) { 421 return getDenseAttribute(shapedType, numElements, elements); 422 } 423 424 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, 425 intptr_t numElements, 426 MlirStringRef *strs) { 427 SmallVector<StringRef, 8> values; 428 values.reserve(numElements); 429 for (intptr_t i = 0; i < numElements; ++i) 430 values.push_back(unwrap(strs[i])); 431 432 return wrap( 433 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values)); 434 } 435 436 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, 437 MlirType shapedType) { 438 return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape( 439 unwrap(shapedType).cast<ShapedType>())); 440 } 441 442 //===----------------------------------------------------------------------===// 443 // Splat accessors. 444 445 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { 446 return unwrap(attr).cast<DenseElementsAttr>().isSplat(); 447 } 448 449 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { 450 return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue()); 451 } 452 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { 453 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>(); 454 } 455 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { 456 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>(); 457 } 458 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { 459 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>(); 460 } 461 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { 462 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>(); 463 } 464 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { 465 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>(); 466 } 467 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { 468 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>(); 469 } 470 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { 471 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>(); 472 } 473 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { 474 return wrap( 475 unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>()); 476 } 477 478 //===----------------------------------------------------------------------===// 479 // Indexed accessors. 480 481 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { 482 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() + 483 pos); 484 } 485 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { 486 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() + 487 pos); 488 } 489 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { 490 return *( 491 unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() + 492 pos); 493 } 494 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { 495 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() + 496 pos); 497 } 498 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { 499 return *( 500 unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() + 501 pos); 502 } 503 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { 504 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() + 505 pos); 506 } 507 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { 508 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() + 509 pos); 510 } 511 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, 512 intptr_t pos) { 513 return wrap( 514 *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() + 515 pos)); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // Raw data accessors. 520 521 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { 522 return static_cast<const void *>( 523 unwrap(attr).cast<DenseElementsAttr>().getRawData().data()); 524 } 525 526 //===----------------------------------------------------------------------===// 527 // Opaque elements attribute. 528 //===----------------------------------------------------------------------===// 529 530 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { 531 return unwrap(attr).isa<OpaqueElementsAttr>(); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // Sparse elements attribute. 536 //===----------------------------------------------------------------------===// 537 538 bool mlirAttributeIsASparseElements(MlirAttribute attr) { 539 return unwrap(attr).isa<SparseElementsAttr>(); 540 } 541 542 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, 543 MlirAttribute denseIndices, 544 MlirAttribute denseValues) { 545 return wrap( 546 SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 547 unwrap(denseIndices).cast<DenseElementsAttr>(), 548 unwrap(denseValues).cast<DenseElementsAttr>())); 549 } 550 551 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { 552 return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices()); 553 } 554 555 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { 556 return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues()); 557 } 558