1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/BuiltinAttributes.h" 10 #include "mlir/CAPI/AffineMap.h" 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Support.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Affine map attribute. 20 //===----------------------------------------------------------------------===// 21 22 bool mlirAttributeIsAAffineMap(MlirAttribute attr) { 23 return unwrap(attr).isa<AffineMapAttr>(); 24 } 25 26 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { 27 return wrap(AffineMapAttr::get(unwrap(map))); 28 } 29 30 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { 31 return wrap(unwrap(attr).cast<AffineMapAttr>().getValue()); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // Array attribute. 36 //===----------------------------------------------------------------------===// 37 38 bool mlirAttributeIsAArray(MlirAttribute attr) { 39 return unwrap(attr).isa<ArrayAttr>(); 40 } 41 42 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, 43 MlirAttribute const *elements) { 44 SmallVector<Attribute, 8> attrs; 45 return wrap( 46 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements), 47 elements, attrs))); 48 } 49 50 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { 51 return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size()); 52 } 53 54 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { 55 return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]); 56 } 57 58 //===----------------------------------------------------------------------===// 59 // Dictionary attribute. 60 //===----------------------------------------------------------------------===// 61 62 bool mlirAttributeIsADictionary(MlirAttribute attr) { 63 return unwrap(attr).isa<DictionaryAttr>(); 64 } 65 66 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, 67 MlirNamedAttribute const *elements) { 68 SmallVector<NamedAttribute, 8> attributes; 69 attributes.reserve(numElements); 70 for (intptr_t i = 0; i < numElements; ++i) 71 attributes.emplace_back( 72 Identifier::get(unwrap(elements[i].name), unwrap(ctx)), 73 unwrap(elements[i].attribute)); 74 return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); 75 } 76 77 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { 78 return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size()); 79 } 80 81 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, 82 intptr_t pos) { 83 NamedAttribute attribute = 84 unwrap(attr).cast<DictionaryAttr>().getValue()[pos]; 85 return {wrap(attribute.first), wrap(attribute.second)}; 86 } 87 88 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, 89 MlirStringRef name) { 90 return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name))); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // Floating point attribute. 95 //===----------------------------------------------------------------------===// 96 97 bool mlirAttributeIsAFloat(MlirAttribute attr) { 98 return unwrap(attr).isa<FloatAttr>(); 99 } 100 101 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, 102 double value) { 103 return wrap(FloatAttr::get(unwrap(type), value)); 104 } 105 106 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, 107 double value) { 108 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); 109 } 110 111 double mlirFloatAttrGetValueDouble(MlirAttribute attr) { 112 return unwrap(attr).cast<FloatAttr>().getValueAsDouble(); 113 } 114 115 //===----------------------------------------------------------------------===// 116 // Integer attribute. 117 //===----------------------------------------------------------------------===// 118 119 bool mlirAttributeIsAInteger(MlirAttribute attr) { 120 return unwrap(attr).isa<IntegerAttr>(); 121 } 122 123 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { 124 return wrap(IntegerAttr::get(unwrap(type), value)); 125 } 126 127 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { 128 return unwrap(attr).cast<IntegerAttr>().getInt(); 129 } 130 131 //===----------------------------------------------------------------------===// 132 // Bool attribute. 133 //===----------------------------------------------------------------------===// 134 135 bool mlirAttributeIsABool(MlirAttribute attr) { 136 return unwrap(attr).isa<BoolAttr>(); 137 } 138 139 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { 140 return wrap(BoolAttr::get(unwrap(ctx), value)); 141 } 142 143 bool mlirBoolAttrGetValue(MlirAttribute attr) { 144 return unwrap(attr).cast<BoolAttr>().getValue(); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // Integer set attribute. 149 //===----------------------------------------------------------------------===// 150 151 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { 152 return unwrap(attr).isa<IntegerSetAttr>(); 153 } 154 155 //===----------------------------------------------------------------------===// 156 // Opaque attribute. 157 //===----------------------------------------------------------------------===// 158 159 bool mlirAttributeIsAOpaque(MlirAttribute attr) { 160 return unwrap(attr).isa<OpaqueAttr>(); 161 } 162 163 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, 164 intptr_t dataLength, const char *data, 165 MlirType type) { 166 return wrap( 167 OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), 168 StringRef(data, dataLength), unwrap(type))); 169 } 170 171 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { 172 return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref()); 173 } 174 175 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { 176 return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData()); 177 } 178 179 //===----------------------------------------------------------------------===// 180 // String attribute. 181 //===----------------------------------------------------------------------===// 182 183 bool mlirAttributeIsAString(MlirAttribute attr) { 184 return unwrap(attr).isa<StringAttr>(); 185 } 186 187 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { 188 return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); 189 } 190 191 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { 192 return wrap(StringAttr::get(unwrap(str), unwrap(type))); 193 } 194 195 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { 196 return wrap(unwrap(attr).cast<StringAttr>().getValue()); 197 } 198 199 //===----------------------------------------------------------------------===// 200 // SymbolRef attribute. 201 //===----------------------------------------------------------------------===// 202 203 bool mlirAttributeIsASymbolRef(MlirAttribute attr) { 204 return unwrap(attr).isa<SymbolRefAttr>(); 205 } 206 207 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, 208 intptr_t numReferences, 209 MlirAttribute const *references) { 210 SmallVector<FlatSymbolRefAttr, 4> refs; 211 refs.reserve(numReferences); 212 for (intptr_t i = 0; i < numReferences; ++i) 213 refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>()); 214 return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); 215 } 216 217 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { 218 return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference()); 219 } 220 221 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { 222 return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference()); 223 } 224 225 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { 226 return static_cast<intptr_t>( 227 unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size()); 228 } 229 230 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, 231 intptr_t pos) { 232 return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // Flat SymbolRef attribute. 237 //===----------------------------------------------------------------------===// 238 239 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { 240 return unwrap(attr).isa<FlatSymbolRefAttr>(); 241 } 242 243 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { 244 return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); 245 } 246 247 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { 248 return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue()); 249 } 250 251 //===----------------------------------------------------------------------===// 252 // Type attribute. 253 //===----------------------------------------------------------------------===// 254 255 bool mlirAttributeIsAType(MlirAttribute attr) { 256 return unwrap(attr).isa<TypeAttr>(); 257 } 258 259 MlirAttribute mlirTypeAttrGet(MlirType type) { 260 return wrap(TypeAttr::get(unwrap(type))); 261 } 262 263 MlirType mlirTypeAttrGetValue(MlirAttribute attr) { 264 return wrap(unwrap(attr).cast<TypeAttr>().getValue()); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Unit attribute. 269 //===----------------------------------------------------------------------===// 270 271 bool mlirAttributeIsAUnit(MlirAttribute attr) { 272 return unwrap(attr).isa<UnitAttr>(); 273 } 274 275 MlirAttribute mlirUnitAttrGet(MlirContext ctx) { 276 return wrap(UnitAttr::get(unwrap(ctx))); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // Elements attributes. 281 //===----------------------------------------------------------------------===// 282 283 bool mlirAttributeIsAElements(MlirAttribute attr) { 284 return unwrap(attr).isa<ElementsAttr>(); 285 } 286 287 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, 288 uint64_t *idxs) { 289 return wrap(unwrap(attr).cast<ElementsAttr>().getValue( 290 llvm::makeArrayRef(idxs, rank))); 291 } 292 293 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, 294 uint64_t *idxs) { 295 return unwrap(attr).cast<ElementsAttr>().isValidIndex( 296 llvm::makeArrayRef(idxs, rank)); 297 } 298 299 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { 300 return unwrap(attr).cast<ElementsAttr>().getNumElements(); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // Dense elements attribute. 305 //===----------------------------------------------------------------------===// 306 307 //===----------------------------------------------------------------------===// 308 // IsA support. 309 310 bool mlirAttributeIsADenseElements(MlirAttribute attr) { 311 return unwrap(attr).isa<DenseElementsAttr>(); 312 } 313 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { 314 return unwrap(attr).isa<DenseIntElementsAttr>(); 315 } 316 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { 317 return unwrap(attr).isa<DenseFPElementsAttr>(); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // Constructors. 322 323 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, 324 intptr_t numElements, 325 MlirAttribute const *elements) { 326 SmallVector<Attribute, 8> attributes; 327 return wrap( 328 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 329 unwrapList(numElements, elements, attributes))); 330 } 331 332 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, 333 MlirAttribute element) { 334 return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 335 unwrap(element))); 336 } 337 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, 338 bool element) { 339 return wrap( 340 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 341 } 342 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, 343 uint32_t element) { 344 return wrap( 345 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 346 } 347 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, 348 int32_t element) { 349 return wrap( 350 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 351 } 352 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, 353 uint64_t element) { 354 return wrap( 355 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 356 } 357 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, 358 int64_t element) { 359 return wrap( 360 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 361 } 362 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, 363 float element) { 364 return wrap( 365 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 366 } 367 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, 368 double element) { 369 return wrap( 370 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element)); 371 } 372 373 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, 374 intptr_t numElements, 375 const int *elements) { 376 SmallVector<bool, 8> values(elements, elements + numElements); 377 return wrap( 378 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values)); 379 } 380 381 /// Creates a dense attribute with elements of the type deduced by templates. 382 template <typename T> 383 static MlirAttribute getDenseAttribute(MlirType shapedType, 384 intptr_t numElements, 385 const T *elements) { 386 return wrap( 387 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 388 llvm::makeArrayRef(elements, numElements))); 389 } 390 391 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, 392 intptr_t numElements, 393 const uint32_t *elements) { 394 return getDenseAttribute(shapedType, numElements, elements); 395 } 396 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, 397 intptr_t numElements, 398 const int32_t *elements) { 399 return getDenseAttribute(shapedType, numElements, elements); 400 } 401 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, 402 intptr_t numElements, 403 const uint64_t *elements) { 404 return getDenseAttribute(shapedType, numElements, elements); 405 } 406 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, 407 intptr_t numElements, 408 const int64_t *elements) { 409 return getDenseAttribute(shapedType, numElements, elements); 410 } 411 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, 412 intptr_t numElements, 413 const float *elements) { 414 return getDenseAttribute(shapedType, numElements, elements); 415 } 416 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, 417 intptr_t numElements, 418 const double *elements) { 419 return getDenseAttribute(shapedType, numElements, elements); 420 } 421 422 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, 423 intptr_t numElements, 424 MlirStringRef *strs) { 425 SmallVector<StringRef, 8> values; 426 values.reserve(numElements); 427 for (intptr_t i = 0; i < numElements; ++i) 428 values.push_back(unwrap(strs[i])); 429 430 return wrap( 431 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values)); 432 } 433 434 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, 435 MlirType shapedType) { 436 return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape( 437 unwrap(shapedType).cast<ShapedType>())); 438 } 439 440 //===----------------------------------------------------------------------===// 441 // Splat accessors. 442 443 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { 444 return unwrap(attr).cast<DenseElementsAttr>().isSplat(); 445 } 446 447 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { 448 return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue()); 449 } 450 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { 451 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>(); 452 } 453 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { 454 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>(); 455 } 456 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { 457 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>(); 458 } 459 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { 460 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>(); 461 } 462 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { 463 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>(); 464 } 465 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { 466 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>(); 467 } 468 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { 469 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>(); 470 } 471 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { 472 return wrap( 473 unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>()); 474 } 475 476 //===----------------------------------------------------------------------===// 477 // Indexed accessors. 478 479 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { 480 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() + 481 pos); 482 } 483 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { 484 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() + 485 pos); 486 } 487 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { 488 return *( 489 unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() + 490 pos); 491 } 492 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { 493 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() + 494 pos); 495 } 496 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { 497 return *( 498 unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() + 499 pos); 500 } 501 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { 502 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() + 503 pos); 504 } 505 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { 506 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() + 507 pos); 508 } 509 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, 510 intptr_t pos) { 511 return wrap( 512 *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() + 513 pos)); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // Raw data accessors. 518 519 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { 520 return static_cast<const void *>( 521 unwrap(attr).cast<DenseElementsAttr>().getRawData().data()); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // Opaque elements attribute. 526 //===----------------------------------------------------------------------===// 527 528 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { 529 return unwrap(attr).isa<OpaqueElementsAttr>(); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // Sparse elements attribute. 534 //===----------------------------------------------------------------------===// 535 536 bool mlirAttributeIsASparseElements(MlirAttribute attr) { 537 return unwrap(attr).isa<SparseElementsAttr>(); 538 } 539 540 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, 541 MlirAttribute denseIndices, 542 MlirAttribute denseValues) { 543 return wrap( 544 SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), 545 unwrap(denseIndices).cast<DenseElementsAttr>(), 546 unwrap(denseValues).cast<DenseElementsAttr>())); 547 } 548 549 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { 550 return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices()); 551 } 552 553 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { 554 return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues()); 555 } 556