1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::SmallVector; 21 using llvm::Twine; 22 23 namespace { 24 25 static MlirStringRef toMlirStringRef(const std::string &s) { 26 return mlirStringRefCreate(s.data(), s.size()); 27 } 28 29 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 30 public: 31 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 32 static constexpr const char *pyClassName = "AffineMapAttr"; 33 using PyConcreteAttribute::PyConcreteAttribute; 34 35 static void bindDerived(ClassTy &c) { 36 c.def_static( 37 "get", 38 [](PyAffineMap &affineMap) { 39 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 40 return PyAffineMapAttribute(affineMap.getContext(), attr); 41 }, 42 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 43 } 44 }; 45 46 template <typename T> 47 static T pyTryCast(py::handle object) { 48 try { 49 return object.cast<T>(); 50 } catch (py::cast_error &err) { 51 std::string msg = 52 std::string( 53 "Invalid attribute when attempting to create an ArrayAttribute (") + 54 err.what() + ")"; 55 throw py::cast_error(msg); 56 } catch (py::reference_cast_error &err) { 57 std::string msg = std::string("Invalid attribute (None?) when attempting " 58 "to create an ArrayAttribute (") + 59 err.what() + ")"; 60 throw py::cast_error(msg); 61 } 62 } 63 64 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 65 public: 66 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 67 static constexpr const char *pyClassName = "ArrayAttr"; 68 using PyConcreteAttribute::PyConcreteAttribute; 69 70 class PyArrayAttributeIterator { 71 public: 72 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 73 74 PyArrayAttributeIterator &dunderIter() { return *this; } 75 76 PyAttribute dunderNext() { 77 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 78 throw py::stop_iteration(); 79 } 80 return PyAttribute(attr.getContext(), 81 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 82 } 83 84 static void bind(py::module &m) { 85 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 86 py::module_local()) 87 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 88 .def("__next__", &PyArrayAttributeIterator::dunderNext); 89 } 90 91 private: 92 PyAttribute attr; 93 int nextIndex = 0; 94 }; 95 96 PyAttribute getItem(intptr_t i) { 97 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 98 } 99 100 static void bindDerived(ClassTy &c) { 101 c.def_static( 102 "get", 103 [](py::list attributes, DefaultingPyMlirContext context) { 104 SmallVector<MlirAttribute> mlirAttributes; 105 mlirAttributes.reserve(py::len(attributes)); 106 for (auto attribute : attributes) { 107 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 108 } 109 MlirAttribute attr = mlirArrayAttrGet( 110 context->get(), mlirAttributes.size(), mlirAttributes.data()); 111 return PyArrayAttribute(context->getRef(), attr); 112 }, 113 py::arg("attributes"), py::arg("context") = py::none(), 114 "Gets a uniqued Array attribute"); 115 c.def("__getitem__", 116 [](PyArrayAttribute &arr, intptr_t i) { 117 if (i >= mlirArrayAttrGetNumElements(arr)) 118 throw py::index_error("ArrayAttribute index out of range"); 119 return arr.getItem(i); 120 }) 121 .def("__len__", 122 [](const PyArrayAttribute &arr) { 123 return mlirArrayAttrGetNumElements(arr); 124 }) 125 .def("__iter__", [](const PyArrayAttribute &arr) { 126 return PyArrayAttributeIterator(arr); 127 }); 128 c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 129 std::vector<MlirAttribute> attributes; 130 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 131 attributes.reserve(numOldElements + py::len(extras)); 132 for (intptr_t i = 0; i < numOldElements; ++i) 133 attributes.push_back(arr.getItem(i)); 134 for (py::handle attr : extras) 135 attributes.push_back(pyTryCast<PyAttribute>(attr)); 136 MlirAttribute arrayAttr = mlirArrayAttrGet( 137 arr.getContext()->get(), attributes.size(), attributes.data()); 138 return PyArrayAttribute(arr.getContext(), arrayAttr); 139 }); 140 } 141 }; 142 143 /// Float Point Attribute subclass - FloatAttr. 144 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 145 public: 146 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 147 static constexpr const char *pyClassName = "FloatAttr"; 148 using PyConcreteAttribute::PyConcreteAttribute; 149 150 static void bindDerived(ClassTy &c) { 151 c.def_static( 152 "get", 153 [](PyType &type, double value, DefaultingPyLocation loc) { 154 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 155 // TODO: Rework error reporting once diagnostic engine is exposed 156 // in C API. 157 if (mlirAttributeIsNull(attr)) { 158 throw SetPyError(PyExc_ValueError, 159 Twine("invalid '") + 160 py::repr(py::cast(type)).cast<std::string>() + 161 "' and expected floating point type."); 162 } 163 return PyFloatAttribute(type.getContext(), attr); 164 }, 165 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 166 "Gets an uniqued float point attribute associated to a type"); 167 c.def_static( 168 "get_f32", 169 [](double value, DefaultingPyMlirContext context) { 170 MlirAttribute attr = mlirFloatAttrDoubleGet( 171 context->get(), mlirF32TypeGet(context->get()), value); 172 return PyFloatAttribute(context->getRef(), attr); 173 }, 174 py::arg("value"), py::arg("context") = py::none(), 175 "Gets an uniqued float point attribute associated to a f32 type"); 176 c.def_static( 177 "get_f64", 178 [](double value, DefaultingPyMlirContext context) { 179 MlirAttribute attr = mlirFloatAttrDoubleGet( 180 context->get(), mlirF64TypeGet(context->get()), value); 181 return PyFloatAttribute(context->getRef(), attr); 182 }, 183 py::arg("value"), py::arg("context") = py::none(), 184 "Gets an uniqued float point attribute associated to a f64 type"); 185 c.def_property_readonly( 186 "value", 187 [](PyFloatAttribute &self) { 188 return mlirFloatAttrGetValueDouble(self); 189 }, 190 "Returns the value of the float point attribute"); 191 } 192 }; 193 194 /// Integer Attribute subclass - IntegerAttr. 195 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 196 public: 197 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 198 static constexpr const char *pyClassName = "IntegerAttr"; 199 using PyConcreteAttribute::PyConcreteAttribute; 200 201 static void bindDerived(ClassTy &c) { 202 c.def_static( 203 "get", 204 [](PyType &type, int64_t value) { 205 MlirAttribute attr = mlirIntegerAttrGet(type, value); 206 return PyIntegerAttribute(type.getContext(), attr); 207 }, 208 py::arg("type"), py::arg("value"), 209 "Gets an uniqued integer attribute associated to a type"); 210 c.def_property_readonly( 211 "value", 212 [](PyIntegerAttribute &self) { 213 return mlirIntegerAttrGetValueInt(self); 214 }, 215 "Returns the value of the integer attribute"); 216 } 217 }; 218 219 /// Bool Attribute subclass - BoolAttr. 220 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 221 public: 222 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 223 static constexpr const char *pyClassName = "BoolAttr"; 224 using PyConcreteAttribute::PyConcreteAttribute; 225 226 static void bindDerived(ClassTy &c) { 227 c.def_static( 228 "get", 229 [](bool value, DefaultingPyMlirContext context) { 230 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 231 return PyBoolAttribute(context->getRef(), attr); 232 }, 233 py::arg("value"), py::arg("context") = py::none(), 234 "Gets an uniqued bool attribute"); 235 c.def_property_readonly( 236 "value", 237 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 238 "Returns the value of the bool attribute"); 239 } 240 }; 241 242 class PyFlatSymbolRefAttribute 243 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 244 public: 245 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 246 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 247 using PyConcreteAttribute::PyConcreteAttribute; 248 249 static void bindDerived(ClassTy &c) { 250 c.def_static( 251 "get", 252 [](std::string value, DefaultingPyMlirContext context) { 253 MlirAttribute attr = 254 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 255 return PyFlatSymbolRefAttribute(context->getRef(), attr); 256 }, 257 py::arg("value"), py::arg("context") = py::none(), 258 "Gets a uniqued FlatSymbolRef attribute"); 259 c.def_property_readonly( 260 "value", 261 [](PyFlatSymbolRefAttribute &self) { 262 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 263 return py::str(stringRef.data, stringRef.length); 264 }, 265 "Returns the value of the FlatSymbolRef attribute as a string"); 266 } 267 }; 268 269 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 270 public: 271 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 272 static constexpr const char *pyClassName = "StringAttr"; 273 using PyConcreteAttribute::PyConcreteAttribute; 274 275 static void bindDerived(ClassTy &c) { 276 c.def_static( 277 "get", 278 [](std::string value, DefaultingPyMlirContext context) { 279 MlirAttribute attr = 280 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 281 return PyStringAttribute(context->getRef(), attr); 282 }, 283 py::arg("value"), py::arg("context") = py::none(), 284 "Gets a uniqued string attribute"); 285 c.def_static( 286 "get_typed", 287 [](PyType &type, std::string value) { 288 MlirAttribute attr = 289 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 290 return PyStringAttribute(type.getContext(), attr); 291 }, 292 293 "Gets a uniqued string attribute associated to a type"); 294 c.def_property_readonly( 295 "value", 296 [](PyStringAttribute &self) { 297 MlirStringRef stringRef = mlirStringAttrGetValue(self); 298 return py::str(stringRef.data, stringRef.length); 299 }, 300 "Returns the value of the string attribute"); 301 } 302 }; 303 304 // TODO: Support construction of bool elements. 305 // TODO: Support construction of string elements. 306 class PyDenseElementsAttribute 307 : public PyConcreteAttribute<PyDenseElementsAttribute> { 308 public: 309 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 310 static constexpr const char *pyClassName = "DenseElementsAttr"; 311 using PyConcreteAttribute::PyConcreteAttribute; 312 313 static PyDenseElementsAttribute 314 getFromBuffer(py::buffer array, bool signless, 315 DefaultingPyMlirContext contextWrapper) { 316 // Request a contiguous view. In exotic cases, this will cause a copy. 317 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 318 Py_buffer *view = new Py_buffer(); 319 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 320 delete view; 321 throw py::error_already_set(); 322 } 323 py::buffer_info arrayInfo(view); 324 325 MlirContext context = contextWrapper->get(); 326 // Switch on the types that can be bulk loaded between the Python and 327 // MLIR-C APIs. 328 // See: https://docs.python.org/3/library/struct.html#format-characters 329 if (arrayInfo.format == "f") { 330 // f32 331 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 332 return PyDenseElementsAttribute( 333 contextWrapper->getRef(), 334 bulkLoad(context, mlirDenseElementsAttrFloatGet, 335 mlirF32TypeGet(context), arrayInfo)); 336 } else if (arrayInfo.format == "d") { 337 // f64 338 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 339 return PyDenseElementsAttribute( 340 contextWrapper->getRef(), 341 bulkLoad(context, mlirDenseElementsAttrDoubleGet, 342 mlirF64TypeGet(context), arrayInfo)); 343 } else if (isSignedIntegerFormat(arrayInfo.format)) { 344 if (arrayInfo.itemsize == 4) { 345 // i32 346 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 347 : mlirIntegerTypeSignedGet(context, 32); 348 return PyDenseElementsAttribute(contextWrapper->getRef(), 349 bulkLoad(context, 350 mlirDenseElementsAttrInt32Get, 351 elementType, arrayInfo)); 352 } else if (arrayInfo.itemsize == 8) { 353 // i64 354 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 355 : mlirIntegerTypeSignedGet(context, 64); 356 return PyDenseElementsAttribute(contextWrapper->getRef(), 357 bulkLoad(context, 358 mlirDenseElementsAttrInt64Get, 359 elementType, arrayInfo)); 360 } 361 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 362 if (arrayInfo.itemsize == 4) { 363 // unsigned i32 364 MlirType elementType = signless 365 ? mlirIntegerTypeGet(context, 32) 366 : mlirIntegerTypeUnsignedGet(context, 32); 367 return PyDenseElementsAttribute(contextWrapper->getRef(), 368 bulkLoad(context, 369 mlirDenseElementsAttrUInt32Get, 370 elementType, arrayInfo)); 371 } else if (arrayInfo.itemsize == 8) { 372 // unsigned i64 373 MlirType elementType = signless 374 ? mlirIntegerTypeGet(context, 64) 375 : mlirIntegerTypeUnsignedGet(context, 64); 376 return PyDenseElementsAttribute(contextWrapper->getRef(), 377 bulkLoad(context, 378 mlirDenseElementsAttrUInt64Get, 379 elementType, arrayInfo)); 380 } 381 } 382 383 // TODO: Fall back to string-based get. 384 std::string message = "unimplemented array format conversion from format: "; 385 message.append(arrayInfo.format); 386 throw SetPyError(PyExc_ValueError, message); 387 } 388 389 static PyDenseElementsAttribute getSplat(PyType shapedType, 390 PyAttribute &elementAttr) { 391 auto contextWrapper = 392 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 393 if (!mlirAttributeIsAInteger(elementAttr) && 394 !mlirAttributeIsAFloat(elementAttr)) { 395 std::string message = "Illegal element type for DenseElementsAttr: "; 396 message.append(py::repr(py::cast(elementAttr))); 397 throw SetPyError(PyExc_ValueError, message); 398 } 399 if (!mlirTypeIsAShaped(shapedType) || 400 !mlirShapedTypeHasStaticShape(shapedType)) { 401 std::string message = 402 "Expected a static ShapedType for the shaped_type parameter: "; 403 message.append(py::repr(py::cast(shapedType))); 404 throw SetPyError(PyExc_ValueError, message); 405 } 406 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 407 MlirType attrType = mlirAttributeGetType(elementAttr); 408 if (!mlirTypeEqual(shapedElementType, attrType)) { 409 std::string message = 410 "Shaped element type and attribute type must be equal: shaped="; 411 message.append(py::repr(py::cast(shapedType))); 412 message.append(", element="); 413 message.append(py::repr(py::cast(elementAttr))); 414 throw SetPyError(PyExc_ValueError, message); 415 } 416 417 MlirAttribute elements = 418 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 419 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 420 } 421 422 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 423 424 py::buffer_info accessBuffer() { 425 MlirType shapedType = mlirAttributeGetType(*this); 426 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 427 428 if (mlirTypeIsAF32(elementType)) { 429 // f32 430 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 431 } else if (mlirTypeIsAF64(elementType)) { 432 // f64 433 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 434 } else if (mlirTypeIsAInteger(elementType) && 435 mlirIntegerTypeGetWidth(elementType) == 32) { 436 if (mlirIntegerTypeIsSignless(elementType) || 437 mlirIntegerTypeIsSigned(elementType)) { 438 // i32 439 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 440 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 441 // unsigned i32 442 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 443 } 444 } else if (mlirTypeIsAInteger(elementType) && 445 mlirIntegerTypeGetWidth(elementType) == 64) { 446 if (mlirIntegerTypeIsSignless(elementType) || 447 mlirIntegerTypeIsSigned(elementType)) { 448 // i64 449 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 450 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 451 // unsigned i64 452 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 453 } 454 } 455 456 std::string message = "unimplemented array format."; 457 throw SetPyError(PyExc_ValueError, message); 458 } 459 460 static void bindDerived(ClassTy &c) { 461 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 462 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 463 py::arg("array"), py::arg("signless") = true, 464 py::arg("context") = py::none(), 465 "Gets from a buffer or ndarray") 466 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 467 py::arg("shaped_type"), py::arg("element_attr"), 468 "Gets a DenseElementsAttr where all values are the same") 469 .def_property_readonly("is_splat", 470 [](PyDenseElementsAttribute &self) -> bool { 471 return mlirDenseElementsAttrIsSplat(self); 472 }) 473 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 474 } 475 476 private: 477 template <typename ElementTy> 478 static MlirAttribute 479 bulkLoad(MlirContext context, 480 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 481 MlirType mlirElementType, py::buffer_info &arrayInfo) { 482 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 483 arrayInfo.shape.begin() + arrayInfo.ndim); 484 MlirAttribute encodingAttr = mlirAttributeGetNull(); 485 auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 486 mlirElementType, encodingAttr); 487 intptr_t numElements = arrayInfo.size; 488 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 489 return ctor(shapedType, numElements, contents); 490 } 491 492 static bool isUnsignedIntegerFormat(const std::string &format) { 493 if (format.empty()) 494 return false; 495 char code = format[0]; 496 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 497 code == 'Q'; 498 } 499 500 static bool isSignedIntegerFormat(const std::string &format) { 501 if (format.empty()) 502 return false; 503 char code = format[0]; 504 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 505 code == 'q'; 506 } 507 508 template <typename Type> 509 py::buffer_info bufferInfo(MlirType shapedType, 510 Type (*value)(MlirAttribute, intptr_t)) { 511 intptr_t rank = mlirShapedTypeGetRank(shapedType); 512 // Prepare the data for the buffer_info. 513 // Buffer is configured for read-only access below. 514 Type *data = static_cast<Type *>( 515 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 516 // Prepare the shape for the buffer_info. 517 SmallVector<intptr_t, 4> shape; 518 for (intptr_t i = 0; i < rank; ++i) 519 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 520 // Prepare the strides for the buffer_info. 521 SmallVector<intptr_t, 4> strides; 522 intptr_t strideFactor = 1; 523 for (intptr_t i = 1; i < rank; ++i) { 524 strideFactor = 1; 525 for (intptr_t j = i; j < rank; ++j) { 526 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 527 } 528 strides.push_back(sizeof(Type) * strideFactor); 529 } 530 strides.push_back(sizeof(Type)); 531 return py::buffer_info(data, sizeof(Type), 532 py::format_descriptor<Type>::format(), rank, shape, 533 strides, /*readonly=*/true); 534 } 535 }; // namespace 536 537 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 538 /// (and boolean) values. Supports element access. 539 class PyDenseIntElementsAttribute 540 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 541 PyDenseElementsAttribute> { 542 public: 543 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 544 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 545 using PyConcreteAttribute::PyConcreteAttribute; 546 547 /// Returns the element at the given linear position. Asserts if the index is 548 /// out of range. 549 py::int_ dunderGetItem(intptr_t pos) { 550 if (pos < 0 || pos >= dunderLen()) { 551 throw SetPyError(PyExc_IndexError, 552 "attempt to access out of bounds element"); 553 } 554 555 MlirType type = mlirAttributeGetType(*this); 556 type = mlirShapedTypeGetElementType(type); 557 assert(mlirTypeIsAInteger(type) && 558 "expected integer element type in dense int elements attribute"); 559 // Dispatch element extraction to an appropriate C function based on the 560 // elemental type of the attribute. py::int_ is implicitly constructible 561 // from any C++ integral type and handles bitwidth correctly. 562 // TODO: consider caching the type properties in the constructor to avoid 563 // querying them on each element access. 564 unsigned width = mlirIntegerTypeGetWidth(type); 565 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 566 if (isUnsigned) { 567 if (width == 1) { 568 return mlirDenseElementsAttrGetBoolValue(*this, pos); 569 } 570 if (width == 32) { 571 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 572 } 573 if (width == 64) { 574 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 575 } 576 } else { 577 if (width == 1) { 578 return mlirDenseElementsAttrGetBoolValue(*this, pos); 579 } 580 if (width == 32) { 581 return mlirDenseElementsAttrGetInt32Value(*this, pos); 582 } 583 if (width == 64) { 584 return mlirDenseElementsAttrGetInt64Value(*this, pos); 585 } 586 } 587 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 588 } 589 590 static void bindDerived(ClassTy &c) { 591 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 592 } 593 }; 594 595 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 596 public: 597 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 598 static constexpr const char *pyClassName = "DictAttr"; 599 using PyConcreteAttribute::PyConcreteAttribute; 600 601 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 602 603 static void bindDerived(ClassTy &c) { 604 c.def("__len__", &PyDictAttribute::dunderLen); 605 c.def_static( 606 "get", 607 [](py::dict attributes, DefaultingPyMlirContext context) { 608 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 609 mlirNamedAttributes.reserve(attributes.size()); 610 for (auto &it : attributes) { 611 auto &mlir_attr = it.second.cast<PyAttribute &>(); 612 auto name = it.first.cast<std::string>(); 613 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 614 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 615 toMlirStringRef(name)), 616 mlir_attr)); 617 } 618 MlirAttribute attr = 619 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 620 mlirNamedAttributes.data()); 621 return PyDictAttribute(context->getRef(), attr); 622 }, 623 py::arg("value") = py::dict(), py::arg("context") = py::none(), 624 "Gets an uniqued dict attribute"); 625 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 626 MlirAttribute attr = 627 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 628 if (mlirAttributeIsNull(attr)) { 629 throw SetPyError(PyExc_KeyError, 630 "attempt to access a non-existent attribute"); 631 } 632 return PyAttribute(self.getContext(), attr); 633 }); 634 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 635 if (index < 0 || index >= self.dunderLen()) { 636 throw SetPyError(PyExc_IndexError, 637 "attempt to access out of bounds attribute"); 638 } 639 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 640 return PyNamedAttribute( 641 namedAttr.attribute, 642 std::string(mlirIdentifierStr(namedAttr.name).data)); 643 }); 644 } 645 }; 646 647 /// Refinement of PyDenseElementsAttribute for attributes containing 648 /// floating-point values. Supports element access. 649 class PyDenseFPElementsAttribute 650 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 651 PyDenseElementsAttribute> { 652 public: 653 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 654 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 655 using PyConcreteAttribute::PyConcreteAttribute; 656 657 py::float_ dunderGetItem(intptr_t pos) { 658 if (pos < 0 || pos >= dunderLen()) { 659 throw SetPyError(PyExc_IndexError, 660 "attempt to access out of bounds element"); 661 } 662 663 MlirType type = mlirAttributeGetType(*this); 664 type = mlirShapedTypeGetElementType(type); 665 // Dispatch element extraction to an appropriate C function based on the 666 // elemental type of the attribute. py::float_ is implicitly constructible 667 // from float and double. 668 // TODO: consider caching the type properties in the constructor to avoid 669 // querying them on each element access. 670 if (mlirTypeIsAF32(type)) { 671 return mlirDenseElementsAttrGetFloatValue(*this, pos); 672 } 673 if (mlirTypeIsAF64(type)) { 674 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 675 } 676 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 677 } 678 679 static void bindDerived(ClassTy &c) { 680 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 681 } 682 }; 683 684 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 685 public: 686 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 687 static constexpr const char *pyClassName = "TypeAttr"; 688 using PyConcreteAttribute::PyConcreteAttribute; 689 690 static void bindDerived(ClassTy &c) { 691 c.def_static( 692 "get", 693 [](PyType value, DefaultingPyMlirContext context) { 694 MlirAttribute attr = mlirTypeAttrGet(value.get()); 695 return PyTypeAttribute(context->getRef(), attr); 696 }, 697 py::arg("value"), py::arg("context") = py::none(), 698 "Gets a uniqued Type attribute"); 699 c.def_property_readonly("value", [](PyTypeAttribute &self) { 700 return PyType(self.getContext()->getRef(), 701 mlirTypeAttrGetValue(self.get())); 702 }); 703 } 704 }; 705 706 /// Unit Attribute subclass. Unit attributes don't have values. 707 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 708 public: 709 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 710 static constexpr const char *pyClassName = "UnitAttr"; 711 using PyConcreteAttribute::PyConcreteAttribute; 712 713 static void bindDerived(ClassTy &c) { 714 c.def_static( 715 "get", 716 [](DefaultingPyMlirContext context) { 717 return PyUnitAttribute(context->getRef(), 718 mlirUnitAttrGet(context->get())); 719 }, 720 py::arg("context") = py::none(), "Create a Unit attribute."); 721 } 722 }; 723 724 } // namespace 725 726 void mlir::python::populateIRAttributes(py::module &m) { 727 PyAffineMapAttribute::bind(m); 728 PyArrayAttribute::bind(m); 729 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 730 PyBoolAttribute::bind(m); 731 PyDenseElementsAttribute::bind(m); 732 PyDenseFPElementsAttribute::bind(m); 733 PyDenseIntElementsAttribute::bind(m); 734 PyDictAttribute::bind(m); 735 PyFlatSymbolRefAttribute::bind(m); 736 PyFloatAttribute::bind(m); 737 PyIntegerAttribute::bind(m); 738 PyStringAttribute::bind(m); 739 PyTypeAttribute::bind(m); 740 PyUnitAttribute::bind(m); 741 } 742