1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::SmallVector; 21 using llvm::StringRef; 22 using llvm::Twine; 23 24 namespace { 25 26 static MlirStringRef toMlirStringRef(const std::string &s) { 27 return mlirStringRefCreate(s.data(), s.size()); 28 } 29 30 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 31 public: 32 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 33 static constexpr const char *pyClassName = "AffineMapAttr"; 34 using PyConcreteAttribute::PyConcreteAttribute; 35 36 static void bindDerived(ClassTy &c) { 37 c.def_static( 38 "get", 39 [](PyAffineMap &affineMap) { 40 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 41 return PyAffineMapAttribute(affineMap.getContext(), attr); 42 }, 43 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 44 } 45 }; 46 47 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 48 public: 49 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 50 static constexpr const char *pyClassName = "ArrayAttr"; 51 using PyConcreteAttribute::PyConcreteAttribute; 52 53 class PyArrayAttributeIterator { 54 public: 55 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 56 57 PyArrayAttributeIterator &dunderIter() { return *this; } 58 59 PyAttribute dunderNext() { 60 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 61 throw py::stop_iteration(); 62 } 63 return PyAttribute(attr.getContext(), 64 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 65 } 66 67 static void bind(py::module &m) { 68 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 69 py::module_local()) 70 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 71 .def("__next__", &PyArrayAttributeIterator::dunderNext); 72 } 73 74 private: 75 PyAttribute attr; 76 int nextIndex = 0; 77 }; 78 79 static void bindDerived(ClassTy &c) { 80 c.def_static( 81 "get", 82 [](py::list attributes, DefaultingPyMlirContext context) { 83 SmallVector<MlirAttribute> mlirAttributes; 84 mlirAttributes.reserve(py::len(attributes)); 85 for (auto attribute : attributes) { 86 try { 87 mlirAttributes.push_back(attribute.cast<PyAttribute>()); 88 } catch (py::cast_error &err) { 89 std::string msg = std::string("Invalid attribute when attempting " 90 "to create an ArrayAttribute (") + 91 err.what() + ")"; 92 throw py::cast_error(msg); 93 } catch (py::reference_cast_error &err) { 94 // This exception seems thrown when the value is "None". 95 std::string msg = 96 std::string("Invalid attribute (None?) when attempting to " 97 "create an ArrayAttribute (") + 98 err.what() + ")"; 99 throw py::cast_error(msg); 100 } 101 } 102 MlirAttribute attr = mlirArrayAttrGet( 103 context->get(), mlirAttributes.size(), mlirAttributes.data()); 104 return PyArrayAttribute(context->getRef(), attr); 105 }, 106 py::arg("attributes"), py::arg("context") = py::none(), 107 "Gets a uniqued Array attribute"); 108 c.def("__getitem__", 109 [](PyArrayAttribute &arr, intptr_t i) { 110 if (i >= mlirArrayAttrGetNumElements(arr)) 111 throw py::index_error("ArrayAttribute index out of range"); 112 return PyAttribute(arr.getContext(), 113 mlirArrayAttrGetElement(arr, i)); 114 }) 115 .def("__len__", 116 [](const PyArrayAttribute &arr) { 117 return mlirArrayAttrGetNumElements(arr); 118 }) 119 .def("__iter__", [](const PyArrayAttribute &arr) { 120 return PyArrayAttributeIterator(arr); 121 }); 122 } 123 }; 124 125 /// Float Point Attribute subclass - FloatAttr. 126 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 127 public: 128 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 129 static constexpr const char *pyClassName = "FloatAttr"; 130 using PyConcreteAttribute::PyConcreteAttribute; 131 132 static void bindDerived(ClassTy &c) { 133 c.def_static( 134 "get", 135 [](PyType &type, double value, DefaultingPyLocation loc) { 136 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 137 // TODO: Rework error reporting once diagnostic engine is exposed 138 // in C API. 139 if (mlirAttributeIsNull(attr)) { 140 throw SetPyError(PyExc_ValueError, 141 Twine("invalid '") + 142 py::repr(py::cast(type)).cast<std::string>() + 143 "' and expected floating point type."); 144 } 145 return PyFloatAttribute(type.getContext(), attr); 146 }, 147 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 148 "Gets an uniqued float point attribute associated to a type"); 149 c.def_static( 150 "get_f32", 151 [](double value, DefaultingPyMlirContext context) { 152 MlirAttribute attr = mlirFloatAttrDoubleGet( 153 context->get(), mlirF32TypeGet(context->get()), value); 154 return PyFloatAttribute(context->getRef(), attr); 155 }, 156 py::arg("value"), py::arg("context") = py::none(), 157 "Gets an uniqued float point attribute associated to a f32 type"); 158 c.def_static( 159 "get_f64", 160 [](double value, DefaultingPyMlirContext context) { 161 MlirAttribute attr = mlirFloatAttrDoubleGet( 162 context->get(), mlirF64TypeGet(context->get()), value); 163 return PyFloatAttribute(context->getRef(), attr); 164 }, 165 py::arg("value"), py::arg("context") = py::none(), 166 "Gets an uniqued float point attribute associated to a f64 type"); 167 c.def_property_readonly( 168 "value", 169 [](PyFloatAttribute &self) { 170 return mlirFloatAttrGetValueDouble(self); 171 }, 172 "Returns the value of the float point attribute"); 173 } 174 }; 175 176 /// Integer Attribute subclass - IntegerAttr. 177 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 178 public: 179 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 180 static constexpr const char *pyClassName = "IntegerAttr"; 181 using PyConcreteAttribute::PyConcreteAttribute; 182 183 static void bindDerived(ClassTy &c) { 184 c.def_static( 185 "get", 186 [](PyType &type, int64_t value) { 187 MlirAttribute attr = mlirIntegerAttrGet(type, value); 188 return PyIntegerAttribute(type.getContext(), attr); 189 }, 190 py::arg("type"), py::arg("value"), 191 "Gets an uniqued integer attribute associated to a type"); 192 c.def_property_readonly( 193 "value", 194 [](PyIntegerAttribute &self) { 195 return mlirIntegerAttrGetValueInt(self); 196 }, 197 "Returns the value of the integer attribute"); 198 } 199 }; 200 201 /// Bool Attribute subclass - BoolAttr. 202 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 203 public: 204 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 205 static constexpr const char *pyClassName = "BoolAttr"; 206 using PyConcreteAttribute::PyConcreteAttribute; 207 208 static void bindDerived(ClassTy &c) { 209 c.def_static( 210 "get", 211 [](bool value, DefaultingPyMlirContext context) { 212 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 213 return PyBoolAttribute(context->getRef(), attr); 214 }, 215 py::arg("value"), py::arg("context") = py::none(), 216 "Gets an uniqued bool attribute"); 217 c.def_property_readonly( 218 "value", 219 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 220 "Returns the value of the bool attribute"); 221 } 222 }; 223 224 class PyFlatSymbolRefAttribute 225 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 226 public: 227 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 228 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 229 using PyConcreteAttribute::PyConcreteAttribute; 230 231 static void bindDerived(ClassTy &c) { 232 c.def_static( 233 "get", 234 [](std::string value, DefaultingPyMlirContext context) { 235 MlirAttribute attr = 236 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 237 return PyFlatSymbolRefAttribute(context->getRef(), attr); 238 }, 239 py::arg("value"), py::arg("context") = py::none(), 240 "Gets a uniqued FlatSymbolRef attribute"); 241 c.def_property_readonly( 242 "value", 243 [](PyFlatSymbolRefAttribute &self) { 244 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 245 return py::str(stringRef.data, stringRef.length); 246 }, 247 "Returns the value of the FlatSymbolRef attribute as a string"); 248 } 249 }; 250 251 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 252 public: 253 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 254 static constexpr const char *pyClassName = "StringAttr"; 255 using PyConcreteAttribute::PyConcreteAttribute; 256 257 static void bindDerived(ClassTy &c) { 258 c.def_static( 259 "get", 260 [](std::string value, DefaultingPyMlirContext context) { 261 MlirAttribute attr = 262 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 263 return PyStringAttribute(context->getRef(), attr); 264 }, 265 py::arg("value"), py::arg("context") = py::none(), 266 "Gets a uniqued string attribute"); 267 c.def_static( 268 "get_typed", 269 [](PyType &type, std::string value) { 270 MlirAttribute attr = 271 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 272 return PyStringAttribute(type.getContext(), attr); 273 }, 274 275 "Gets a uniqued string attribute associated to a type"); 276 c.def_property_readonly( 277 "value", 278 [](PyStringAttribute &self) { 279 MlirStringRef stringRef = mlirStringAttrGetValue(self); 280 return py::str(stringRef.data, stringRef.length); 281 }, 282 "Returns the value of the string attribute"); 283 } 284 }; 285 286 // TODO: Support construction of bool elements. 287 // TODO: Support construction of string elements. 288 class PyDenseElementsAttribute 289 : public PyConcreteAttribute<PyDenseElementsAttribute> { 290 public: 291 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 292 static constexpr const char *pyClassName = "DenseElementsAttr"; 293 using PyConcreteAttribute::PyConcreteAttribute; 294 295 static PyDenseElementsAttribute 296 getFromBuffer(py::buffer array, bool signless, 297 DefaultingPyMlirContext contextWrapper) { 298 // Request a contiguous view. In exotic cases, this will cause a copy. 299 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 300 Py_buffer *view = new Py_buffer(); 301 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 302 delete view; 303 throw py::error_already_set(); 304 } 305 py::buffer_info arrayInfo(view); 306 307 MlirContext context = contextWrapper->get(); 308 // Switch on the types that can be bulk loaded between the Python and 309 // MLIR-C APIs. 310 // See: https://docs.python.org/3/library/struct.html#format-characters 311 if (arrayInfo.format == "f") { 312 // f32 313 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 314 return PyDenseElementsAttribute( 315 contextWrapper->getRef(), 316 bulkLoad(context, mlirDenseElementsAttrFloatGet, 317 mlirF32TypeGet(context), arrayInfo)); 318 } else if (arrayInfo.format == "d") { 319 // f64 320 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 321 return PyDenseElementsAttribute( 322 contextWrapper->getRef(), 323 bulkLoad(context, mlirDenseElementsAttrDoubleGet, 324 mlirF64TypeGet(context), arrayInfo)); 325 } else if (isSignedIntegerFormat(arrayInfo.format)) { 326 if (arrayInfo.itemsize == 4) { 327 // i32 328 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 329 : mlirIntegerTypeSignedGet(context, 32); 330 return PyDenseElementsAttribute(contextWrapper->getRef(), 331 bulkLoad(context, 332 mlirDenseElementsAttrInt32Get, 333 elementType, arrayInfo)); 334 } else if (arrayInfo.itemsize == 8) { 335 // i64 336 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 337 : mlirIntegerTypeSignedGet(context, 64); 338 return PyDenseElementsAttribute(contextWrapper->getRef(), 339 bulkLoad(context, 340 mlirDenseElementsAttrInt64Get, 341 elementType, arrayInfo)); 342 } 343 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 344 if (arrayInfo.itemsize == 4) { 345 // unsigned i32 346 MlirType elementType = signless 347 ? mlirIntegerTypeGet(context, 32) 348 : mlirIntegerTypeUnsignedGet(context, 32); 349 return PyDenseElementsAttribute(contextWrapper->getRef(), 350 bulkLoad(context, 351 mlirDenseElementsAttrUInt32Get, 352 elementType, arrayInfo)); 353 } else if (arrayInfo.itemsize == 8) { 354 // unsigned i64 355 MlirType elementType = signless 356 ? mlirIntegerTypeGet(context, 64) 357 : mlirIntegerTypeUnsignedGet(context, 64); 358 return PyDenseElementsAttribute(contextWrapper->getRef(), 359 bulkLoad(context, 360 mlirDenseElementsAttrUInt64Get, 361 elementType, arrayInfo)); 362 } 363 } 364 365 // TODO: Fall back to string-based get. 366 std::string message = "unimplemented array format conversion from format: "; 367 message.append(arrayInfo.format); 368 throw SetPyError(PyExc_ValueError, message); 369 } 370 371 static PyDenseElementsAttribute getSplat(PyType shapedType, 372 PyAttribute &elementAttr) { 373 auto contextWrapper = 374 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 375 if (!mlirAttributeIsAInteger(elementAttr) && 376 !mlirAttributeIsAFloat(elementAttr)) { 377 std::string message = "Illegal element type for DenseElementsAttr: "; 378 message.append(py::repr(py::cast(elementAttr))); 379 throw SetPyError(PyExc_ValueError, message); 380 } 381 if (!mlirTypeIsAShaped(shapedType) || 382 !mlirShapedTypeHasStaticShape(shapedType)) { 383 std::string message = 384 "Expected a static ShapedType for the shaped_type parameter: "; 385 message.append(py::repr(py::cast(shapedType))); 386 throw SetPyError(PyExc_ValueError, message); 387 } 388 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 389 MlirType attrType = mlirAttributeGetType(elementAttr); 390 if (!mlirTypeEqual(shapedElementType, attrType)) { 391 std::string message = 392 "Shaped element type and attribute type must be equal: shaped="; 393 message.append(py::repr(py::cast(shapedType))); 394 message.append(", element="); 395 message.append(py::repr(py::cast(elementAttr))); 396 throw SetPyError(PyExc_ValueError, message); 397 } 398 399 MlirAttribute elements = 400 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 401 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 402 } 403 404 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 405 406 py::buffer_info accessBuffer() { 407 MlirType shapedType = mlirAttributeGetType(*this); 408 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 409 410 if (mlirTypeIsAF32(elementType)) { 411 // f32 412 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 413 } else if (mlirTypeIsAF64(elementType)) { 414 // f64 415 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 416 } else if (mlirTypeIsAInteger(elementType) && 417 mlirIntegerTypeGetWidth(elementType) == 32) { 418 if (mlirIntegerTypeIsSignless(elementType) || 419 mlirIntegerTypeIsSigned(elementType)) { 420 // i32 421 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 422 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 423 // unsigned i32 424 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 425 } 426 } else if (mlirTypeIsAInteger(elementType) && 427 mlirIntegerTypeGetWidth(elementType) == 64) { 428 if (mlirIntegerTypeIsSignless(elementType) || 429 mlirIntegerTypeIsSigned(elementType)) { 430 // i64 431 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 432 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 433 // unsigned i64 434 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 435 } 436 } 437 438 std::string message = "unimplemented array format."; 439 throw SetPyError(PyExc_ValueError, message); 440 } 441 442 static void bindDerived(ClassTy &c) { 443 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 444 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 445 py::arg("array"), py::arg("signless") = true, 446 py::arg("context") = py::none(), 447 "Gets from a buffer or ndarray") 448 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 449 py::arg("shaped_type"), py::arg("element_attr"), 450 "Gets a DenseElementsAttr where all values are the same") 451 .def_property_readonly("is_splat", 452 [](PyDenseElementsAttribute &self) -> bool { 453 return mlirDenseElementsAttrIsSplat(self); 454 }) 455 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 456 } 457 458 private: 459 template <typename ElementTy> 460 static MlirAttribute 461 bulkLoad(MlirContext context, 462 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 463 MlirType mlirElementType, py::buffer_info &arrayInfo) { 464 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 465 arrayInfo.shape.begin() + arrayInfo.ndim); 466 MlirAttribute encodingAttr = mlirAttributeGetNull(); 467 auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 468 mlirElementType, encodingAttr); 469 intptr_t numElements = arrayInfo.size; 470 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 471 return ctor(shapedType, numElements, contents); 472 } 473 474 static bool isUnsignedIntegerFormat(const std::string &format) { 475 if (format.empty()) 476 return false; 477 char code = format[0]; 478 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 479 code == 'Q'; 480 } 481 482 static bool isSignedIntegerFormat(const std::string &format) { 483 if (format.empty()) 484 return false; 485 char code = format[0]; 486 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 487 code == 'q'; 488 } 489 490 template <typename Type> 491 py::buffer_info bufferInfo(MlirType shapedType, 492 Type (*value)(MlirAttribute, intptr_t)) { 493 intptr_t rank = mlirShapedTypeGetRank(shapedType); 494 // Prepare the data for the buffer_info. 495 // Buffer is configured for read-only access below. 496 Type *data = static_cast<Type *>( 497 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 498 // Prepare the shape for the buffer_info. 499 SmallVector<intptr_t, 4> shape; 500 for (intptr_t i = 0; i < rank; ++i) 501 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 502 // Prepare the strides for the buffer_info. 503 SmallVector<intptr_t, 4> strides; 504 intptr_t strideFactor = 1; 505 for (intptr_t i = 1; i < rank; ++i) { 506 strideFactor = 1; 507 for (intptr_t j = i; j < rank; ++j) { 508 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 509 } 510 strides.push_back(sizeof(Type) * strideFactor); 511 } 512 strides.push_back(sizeof(Type)); 513 return py::buffer_info(data, sizeof(Type), 514 py::format_descriptor<Type>::format(), rank, shape, 515 strides, /*readonly=*/true); 516 } 517 }; // namespace 518 519 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 520 /// (and boolean) values. Supports element access. 521 class PyDenseIntElementsAttribute 522 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 523 PyDenseElementsAttribute> { 524 public: 525 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 526 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 527 using PyConcreteAttribute::PyConcreteAttribute; 528 529 /// Returns the element at the given linear position. Asserts if the index is 530 /// out of range. 531 py::int_ dunderGetItem(intptr_t pos) { 532 if (pos < 0 || pos >= dunderLen()) { 533 throw SetPyError(PyExc_IndexError, 534 "attempt to access out of bounds element"); 535 } 536 537 MlirType type = mlirAttributeGetType(*this); 538 type = mlirShapedTypeGetElementType(type); 539 assert(mlirTypeIsAInteger(type) && 540 "expected integer element type in dense int elements attribute"); 541 // Dispatch element extraction to an appropriate C function based on the 542 // elemental type of the attribute. py::int_ is implicitly constructible 543 // from any C++ integral type and handles bitwidth correctly. 544 // TODO: consider caching the type properties in the constructor to avoid 545 // querying them on each element access. 546 unsigned width = mlirIntegerTypeGetWidth(type); 547 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 548 if (isUnsigned) { 549 if (width == 1) { 550 return mlirDenseElementsAttrGetBoolValue(*this, pos); 551 } 552 if (width == 32) { 553 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 554 } 555 if (width == 64) { 556 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 557 } 558 } else { 559 if (width == 1) { 560 return mlirDenseElementsAttrGetBoolValue(*this, pos); 561 } 562 if (width == 32) { 563 return mlirDenseElementsAttrGetInt32Value(*this, pos); 564 } 565 if (width == 64) { 566 return mlirDenseElementsAttrGetInt64Value(*this, pos); 567 } 568 } 569 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 570 } 571 572 static void bindDerived(ClassTy &c) { 573 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 574 } 575 }; 576 577 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 578 public: 579 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 580 static constexpr const char *pyClassName = "DictAttr"; 581 using PyConcreteAttribute::PyConcreteAttribute; 582 583 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 584 585 static void bindDerived(ClassTy &c) { 586 c.def("__len__", &PyDictAttribute::dunderLen); 587 c.def_static( 588 "get", 589 [](py::dict attributes, DefaultingPyMlirContext context) { 590 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 591 mlirNamedAttributes.reserve(attributes.size()); 592 for (auto &it : attributes) { 593 auto &mlir_attr = it.second.cast<PyAttribute &>(); 594 auto name = it.first.cast<std::string>(); 595 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 596 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 597 toMlirStringRef(name)), 598 mlir_attr)); 599 } 600 MlirAttribute attr = 601 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 602 mlirNamedAttributes.data()); 603 return PyDictAttribute(context->getRef(), attr); 604 }, 605 py::arg("value"), py::arg("context") = py::none(), 606 "Gets an uniqued dict attribute"); 607 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 608 MlirAttribute attr = 609 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 610 if (mlirAttributeIsNull(attr)) { 611 throw SetPyError(PyExc_KeyError, 612 "attempt to access a non-existent attribute"); 613 } 614 return PyAttribute(self.getContext(), attr); 615 }); 616 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 617 if (index < 0 || index >= self.dunderLen()) { 618 throw SetPyError(PyExc_IndexError, 619 "attempt to access out of bounds attribute"); 620 } 621 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 622 return PyNamedAttribute( 623 namedAttr.attribute, 624 std::string(mlirIdentifierStr(namedAttr.name).data)); 625 }); 626 } 627 }; 628 629 /// Refinement of PyDenseElementsAttribute for attributes containing 630 /// floating-point values. Supports element access. 631 class PyDenseFPElementsAttribute 632 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 633 PyDenseElementsAttribute> { 634 public: 635 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 636 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 637 using PyConcreteAttribute::PyConcreteAttribute; 638 639 py::float_ dunderGetItem(intptr_t pos) { 640 if (pos < 0 || pos >= dunderLen()) { 641 throw SetPyError(PyExc_IndexError, 642 "attempt to access out of bounds element"); 643 } 644 645 MlirType type = mlirAttributeGetType(*this); 646 type = mlirShapedTypeGetElementType(type); 647 // Dispatch element extraction to an appropriate C function based on the 648 // elemental type of the attribute. py::float_ is implicitly constructible 649 // from float and double. 650 // TODO: consider caching the type properties in the constructor to avoid 651 // querying them on each element access. 652 if (mlirTypeIsAF32(type)) { 653 return mlirDenseElementsAttrGetFloatValue(*this, pos); 654 } 655 if (mlirTypeIsAF64(type)) { 656 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 657 } 658 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 659 } 660 661 static void bindDerived(ClassTy &c) { 662 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 663 } 664 }; 665 666 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 667 public: 668 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 669 static constexpr const char *pyClassName = "TypeAttr"; 670 using PyConcreteAttribute::PyConcreteAttribute; 671 672 static void bindDerived(ClassTy &c) { 673 c.def_static( 674 "get", 675 [](PyType value, DefaultingPyMlirContext context) { 676 MlirAttribute attr = mlirTypeAttrGet(value.get()); 677 return PyTypeAttribute(context->getRef(), attr); 678 }, 679 py::arg("value"), py::arg("context") = py::none(), 680 "Gets a uniqued Type attribute"); 681 c.def_property_readonly("value", [](PyTypeAttribute &self) { 682 return PyType(self.getContext()->getRef(), 683 mlirTypeAttrGetValue(self.get())); 684 }); 685 } 686 }; 687 688 /// Unit Attribute subclass. Unit attributes don't have values. 689 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 690 public: 691 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 692 static constexpr const char *pyClassName = "UnitAttr"; 693 using PyConcreteAttribute::PyConcreteAttribute; 694 695 static void bindDerived(ClassTy &c) { 696 c.def_static( 697 "get", 698 [](DefaultingPyMlirContext context) { 699 return PyUnitAttribute(context->getRef(), 700 mlirUnitAttrGet(context->get())); 701 }, 702 py::arg("context") = py::none(), "Create a Unit attribute."); 703 } 704 }; 705 706 } // namespace 707 708 void mlir::python::populateIRAttributes(py::module &m) { 709 PyAffineMapAttribute::bind(m); 710 PyArrayAttribute::bind(m); 711 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 712 PyBoolAttribute::bind(m); 713 PyDenseElementsAttribute::bind(m); 714 PyDenseFPElementsAttribute::bind(m); 715 PyDenseIntElementsAttribute::bind(m); 716 PyDictAttribute::bind(m); 717 PyFlatSymbolRefAttribute::bind(m); 718 PyFloatAttribute::bind(m); 719 PyIntegerAttribute::bind(m); 720 PyStringAttribute::bind(m); 721 PyTypeAttribute::bind(m); 722 PyUnitAttribute::bind(m); 723 } 724