1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::SmallVector; 21 using llvm::StringRef; 22 using llvm::Twine; 23 24 namespace { 25 26 static MlirStringRef toMlirStringRef(const std::string &s) { 27 return mlirStringRefCreate(s.data(), s.size()); 28 } 29 30 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 31 public: 32 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 33 static constexpr const char *pyClassName = "AffineMapAttr"; 34 using PyConcreteAttribute::PyConcreteAttribute; 35 36 static void bindDerived(ClassTy &c) { 37 c.def_static( 38 "get", 39 [](PyAffineMap &affineMap) { 40 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 41 return PyAffineMapAttribute(affineMap.getContext(), attr); 42 }, 43 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 44 } 45 }; 46 47 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 48 public: 49 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 50 static constexpr const char *pyClassName = "ArrayAttr"; 51 using PyConcreteAttribute::PyConcreteAttribute; 52 53 class PyArrayAttributeIterator { 54 public: 55 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 56 57 PyArrayAttributeIterator &dunderIter() { return *this; } 58 59 PyAttribute dunderNext() { 60 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 61 throw py::stop_iteration(); 62 } 63 return PyAttribute(attr.getContext(), 64 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 65 } 66 67 static void bind(py::module &m) { 68 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 69 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 70 .def("__next__", &PyArrayAttributeIterator::dunderNext); 71 } 72 73 private: 74 PyAttribute attr; 75 int nextIndex = 0; 76 }; 77 78 static void bindDerived(ClassTy &c) { 79 c.def_static( 80 "get", 81 [](py::list attributes, DefaultingPyMlirContext context) { 82 SmallVector<MlirAttribute> mlirAttributes; 83 mlirAttributes.reserve(py::len(attributes)); 84 for (auto attribute : attributes) { 85 try { 86 mlirAttributes.push_back(attribute.cast<PyAttribute>()); 87 } catch (py::cast_error &err) { 88 std::string msg = std::string("Invalid attribute when attempting " 89 "to create an ArrayAttribute (") + 90 err.what() + ")"; 91 throw py::cast_error(msg); 92 } catch (py::reference_cast_error &err) { 93 // This exception seems thrown when the value is "None". 94 std::string msg = 95 std::string("Invalid attribute (None?) when attempting to " 96 "create an ArrayAttribute (") + 97 err.what() + ")"; 98 throw py::cast_error(msg); 99 } 100 } 101 MlirAttribute attr = mlirArrayAttrGet( 102 context->get(), mlirAttributes.size(), mlirAttributes.data()); 103 return PyArrayAttribute(context->getRef(), attr); 104 }, 105 py::arg("attributes"), py::arg("context") = py::none(), 106 "Gets a uniqued Array attribute"); 107 c.def("__getitem__", 108 [](PyArrayAttribute &arr, intptr_t i) { 109 if (i >= mlirArrayAttrGetNumElements(arr)) 110 throw py::index_error("ArrayAttribute index out of range"); 111 return PyAttribute(arr.getContext(), 112 mlirArrayAttrGetElement(arr, i)); 113 }) 114 .def("__len__", 115 [](const PyArrayAttribute &arr) { 116 return mlirArrayAttrGetNumElements(arr); 117 }) 118 .def("__iter__", [](const PyArrayAttribute &arr) { 119 return PyArrayAttributeIterator(arr); 120 }); 121 } 122 }; 123 124 /// Float Point Attribute subclass - FloatAttr. 125 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 126 public: 127 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 128 static constexpr const char *pyClassName = "FloatAttr"; 129 using PyConcreteAttribute::PyConcreteAttribute; 130 131 static void bindDerived(ClassTy &c) { 132 c.def_static( 133 "get", 134 [](PyType &type, double value, DefaultingPyLocation loc) { 135 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 136 // TODO: Rework error reporting once diagnostic engine is exposed 137 // in C API. 138 if (mlirAttributeIsNull(attr)) { 139 throw SetPyError(PyExc_ValueError, 140 Twine("invalid '") + 141 py::repr(py::cast(type)).cast<std::string>() + 142 "' and expected floating point type."); 143 } 144 return PyFloatAttribute(type.getContext(), attr); 145 }, 146 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 147 "Gets an uniqued float point attribute associated to a type"); 148 c.def_static( 149 "get_f32", 150 [](double value, DefaultingPyMlirContext context) { 151 MlirAttribute attr = mlirFloatAttrDoubleGet( 152 context->get(), mlirF32TypeGet(context->get()), value); 153 return PyFloatAttribute(context->getRef(), attr); 154 }, 155 py::arg("value"), py::arg("context") = py::none(), 156 "Gets an uniqued float point attribute associated to a f32 type"); 157 c.def_static( 158 "get_f64", 159 [](double value, DefaultingPyMlirContext context) { 160 MlirAttribute attr = mlirFloatAttrDoubleGet( 161 context->get(), mlirF64TypeGet(context->get()), value); 162 return PyFloatAttribute(context->getRef(), attr); 163 }, 164 py::arg("value"), py::arg("context") = py::none(), 165 "Gets an uniqued float point attribute associated to a f64 type"); 166 c.def_property_readonly( 167 "value", 168 [](PyFloatAttribute &self) { 169 return mlirFloatAttrGetValueDouble(self); 170 }, 171 "Returns the value of the float point attribute"); 172 } 173 }; 174 175 /// Integer Attribute subclass - IntegerAttr. 176 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 177 public: 178 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 179 static constexpr const char *pyClassName = "IntegerAttr"; 180 using PyConcreteAttribute::PyConcreteAttribute; 181 182 static void bindDerived(ClassTy &c) { 183 c.def_static( 184 "get", 185 [](PyType &type, int64_t value) { 186 MlirAttribute attr = mlirIntegerAttrGet(type, value); 187 return PyIntegerAttribute(type.getContext(), attr); 188 }, 189 py::arg("type"), py::arg("value"), 190 "Gets an uniqued integer attribute associated to a type"); 191 c.def_property_readonly( 192 "value", 193 [](PyIntegerAttribute &self) { 194 return mlirIntegerAttrGetValueInt(self); 195 }, 196 "Returns the value of the integer attribute"); 197 } 198 }; 199 200 /// Bool Attribute subclass - BoolAttr. 201 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 202 public: 203 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 204 static constexpr const char *pyClassName = "BoolAttr"; 205 using PyConcreteAttribute::PyConcreteAttribute; 206 207 static void bindDerived(ClassTy &c) { 208 c.def_static( 209 "get", 210 [](bool value, DefaultingPyMlirContext context) { 211 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 212 return PyBoolAttribute(context->getRef(), attr); 213 }, 214 py::arg("value"), py::arg("context") = py::none(), 215 "Gets an uniqued bool attribute"); 216 c.def_property_readonly( 217 "value", 218 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 219 "Returns the value of the bool attribute"); 220 } 221 }; 222 223 class PyFlatSymbolRefAttribute 224 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 225 public: 226 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 227 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 228 using PyConcreteAttribute::PyConcreteAttribute; 229 230 static void bindDerived(ClassTy &c) { 231 c.def_static( 232 "get", 233 [](std::string value, DefaultingPyMlirContext context) { 234 MlirAttribute attr = 235 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 236 return PyFlatSymbolRefAttribute(context->getRef(), attr); 237 }, 238 py::arg("value"), py::arg("context") = py::none(), 239 "Gets a uniqued FlatSymbolRef attribute"); 240 c.def_property_readonly( 241 "value", 242 [](PyFlatSymbolRefAttribute &self) { 243 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 244 return py::str(stringRef.data, stringRef.length); 245 }, 246 "Returns the value of the FlatSymbolRef attribute as a string"); 247 } 248 }; 249 250 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 251 public: 252 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 253 static constexpr const char *pyClassName = "StringAttr"; 254 using PyConcreteAttribute::PyConcreteAttribute; 255 256 static void bindDerived(ClassTy &c) { 257 c.def_static( 258 "get", 259 [](std::string value, DefaultingPyMlirContext context) { 260 MlirAttribute attr = 261 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 262 return PyStringAttribute(context->getRef(), attr); 263 }, 264 py::arg("value"), py::arg("context") = py::none(), 265 "Gets a uniqued string attribute"); 266 c.def_static( 267 "get_typed", 268 [](PyType &type, std::string value) { 269 MlirAttribute attr = 270 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 271 return PyStringAttribute(type.getContext(), attr); 272 }, 273 274 "Gets a uniqued string attribute associated to a type"); 275 c.def_property_readonly( 276 "value", 277 [](PyStringAttribute &self) { 278 MlirStringRef stringRef = mlirStringAttrGetValue(self); 279 return py::str(stringRef.data, stringRef.length); 280 }, 281 "Returns the value of the string attribute"); 282 } 283 }; 284 285 // TODO: Support construction of bool elements. 286 // TODO: Support construction of string elements. 287 class PyDenseElementsAttribute 288 : public PyConcreteAttribute<PyDenseElementsAttribute> { 289 public: 290 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 291 static constexpr const char *pyClassName = "DenseElementsAttr"; 292 using PyConcreteAttribute::PyConcreteAttribute; 293 294 static PyDenseElementsAttribute 295 getFromBuffer(py::buffer array, bool signless, 296 DefaultingPyMlirContext contextWrapper) { 297 // Request a contiguous view. In exotic cases, this will cause a copy. 298 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 299 Py_buffer *view = new Py_buffer(); 300 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 301 delete view; 302 throw py::error_already_set(); 303 } 304 py::buffer_info arrayInfo(view); 305 306 MlirContext context = contextWrapper->get(); 307 // Switch on the types that can be bulk loaded between the Python and 308 // MLIR-C APIs. 309 // See: https://docs.python.org/3/library/struct.html#format-characters 310 if (arrayInfo.format == "f") { 311 // f32 312 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 313 return PyDenseElementsAttribute( 314 contextWrapper->getRef(), 315 bulkLoad(context, mlirDenseElementsAttrFloatGet, 316 mlirF32TypeGet(context), arrayInfo)); 317 } else if (arrayInfo.format == "d") { 318 // f64 319 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 320 return PyDenseElementsAttribute( 321 contextWrapper->getRef(), 322 bulkLoad(context, mlirDenseElementsAttrDoubleGet, 323 mlirF64TypeGet(context), arrayInfo)); 324 } else if (isSignedIntegerFormat(arrayInfo.format)) { 325 if (arrayInfo.itemsize == 4) { 326 // i32 327 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 328 : mlirIntegerTypeSignedGet(context, 32); 329 return PyDenseElementsAttribute(contextWrapper->getRef(), 330 bulkLoad(context, 331 mlirDenseElementsAttrInt32Get, 332 elementType, arrayInfo)); 333 } else if (arrayInfo.itemsize == 8) { 334 // i64 335 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 336 : mlirIntegerTypeSignedGet(context, 64); 337 return PyDenseElementsAttribute(contextWrapper->getRef(), 338 bulkLoad(context, 339 mlirDenseElementsAttrInt64Get, 340 elementType, arrayInfo)); 341 } 342 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 343 if (arrayInfo.itemsize == 4) { 344 // unsigned i32 345 MlirType elementType = signless 346 ? mlirIntegerTypeGet(context, 32) 347 : mlirIntegerTypeUnsignedGet(context, 32); 348 return PyDenseElementsAttribute(contextWrapper->getRef(), 349 bulkLoad(context, 350 mlirDenseElementsAttrUInt32Get, 351 elementType, arrayInfo)); 352 } else if (arrayInfo.itemsize == 8) { 353 // unsigned i64 354 MlirType elementType = signless 355 ? mlirIntegerTypeGet(context, 64) 356 : mlirIntegerTypeUnsignedGet(context, 64); 357 return PyDenseElementsAttribute(contextWrapper->getRef(), 358 bulkLoad(context, 359 mlirDenseElementsAttrUInt64Get, 360 elementType, arrayInfo)); 361 } 362 } 363 364 // TODO: Fall back to string-based get. 365 std::string message = "unimplemented array format conversion from format: "; 366 message.append(arrayInfo.format); 367 throw SetPyError(PyExc_ValueError, message); 368 } 369 370 static PyDenseElementsAttribute getSplat(PyType shapedType, 371 PyAttribute &elementAttr) { 372 auto contextWrapper = 373 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 374 if (!mlirAttributeIsAInteger(elementAttr) && 375 !mlirAttributeIsAFloat(elementAttr)) { 376 std::string message = "Illegal element type for DenseElementsAttr: "; 377 message.append(py::repr(py::cast(elementAttr))); 378 throw SetPyError(PyExc_ValueError, message); 379 } 380 if (!mlirTypeIsAShaped(shapedType) || 381 !mlirShapedTypeHasStaticShape(shapedType)) { 382 std::string message = 383 "Expected a static ShapedType for the shaped_type parameter: "; 384 message.append(py::repr(py::cast(shapedType))); 385 throw SetPyError(PyExc_ValueError, message); 386 } 387 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 388 MlirType attrType = mlirAttributeGetType(elementAttr); 389 if (!mlirTypeEqual(shapedElementType, attrType)) { 390 std::string message = 391 "Shaped element type and attribute type must be equal: shaped="; 392 message.append(py::repr(py::cast(shapedType))); 393 message.append(", element="); 394 message.append(py::repr(py::cast(elementAttr))); 395 throw SetPyError(PyExc_ValueError, message); 396 } 397 398 MlirAttribute elements = 399 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 400 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 401 } 402 403 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 404 405 py::buffer_info accessBuffer() { 406 MlirType shapedType = mlirAttributeGetType(*this); 407 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 408 409 if (mlirTypeIsAF32(elementType)) { 410 // f32 411 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 412 } else if (mlirTypeIsAF64(elementType)) { 413 // f64 414 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 415 } else if (mlirTypeIsAInteger(elementType) && 416 mlirIntegerTypeGetWidth(elementType) == 32) { 417 if (mlirIntegerTypeIsSignless(elementType) || 418 mlirIntegerTypeIsSigned(elementType)) { 419 // i32 420 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 421 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 422 // unsigned i32 423 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 424 } 425 } else if (mlirTypeIsAInteger(elementType) && 426 mlirIntegerTypeGetWidth(elementType) == 64) { 427 if (mlirIntegerTypeIsSignless(elementType) || 428 mlirIntegerTypeIsSigned(elementType)) { 429 // i64 430 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 431 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 432 // unsigned i64 433 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 434 } 435 } 436 437 std::string message = "unimplemented array format."; 438 throw SetPyError(PyExc_ValueError, message); 439 } 440 441 static void bindDerived(ClassTy &c) { 442 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 443 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 444 py::arg("array"), py::arg("signless") = true, 445 py::arg("context") = py::none(), 446 "Gets from a buffer or ndarray") 447 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 448 py::arg("shaped_type"), py::arg("element_attr"), 449 "Gets a DenseElementsAttr where all values are the same") 450 .def_property_readonly("is_splat", 451 [](PyDenseElementsAttribute &self) -> bool { 452 return mlirDenseElementsAttrIsSplat(self); 453 }) 454 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 455 } 456 457 private: 458 template <typename ElementTy> 459 static MlirAttribute 460 bulkLoad(MlirContext context, 461 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 462 MlirType mlirElementType, py::buffer_info &arrayInfo) { 463 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 464 arrayInfo.shape.begin() + arrayInfo.ndim); 465 MlirAttribute encodingAttr = mlirAttributeGetNull(); 466 auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 467 mlirElementType, encodingAttr); 468 intptr_t numElements = arrayInfo.size; 469 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 470 return ctor(shapedType, numElements, contents); 471 } 472 473 static bool isUnsignedIntegerFormat(const std::string &format) { 474 if (format.empty()) 475 return false; 476 char code = format[0]; 477 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 478 code == 'Q'; 479 } 480 481 static bool isSignedIntegerFormat(const std::string &format) { 482 if (format.empty()) 483 return false; 484 char code = format[0]; 485 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 486 code == 'q'; 487 } 488 489 template <typename Type> 490 py::buffer_info bufferInfo(MlirType shapedType, 491 Type (*value)(MlirAttribute, intptr_t)) { 492 intptr_t rank = mlirShapedTypeGetRank(shapedType); 493 // Prepare the data for the buffer_info. 494 // Buffer is configured for read-only access below. 495 Type *data = static_cast<Type *>( 496 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 497 // Prepare the shape for the buffer_info. 498 SmallVector<intptr_t, 4> shape; 499 for (intptr_t i = 0; i < rank; ++i) 500 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 501 // Prepare the strides for the buffer_info. 502 SmallVector<intptr_t, 4> strides; 503 intptr_t strideFactor = 1; 504 for (intptr_t i = 1; i < rank; ++i) { 505 strideFactor = 1; 506 for (intptr_t j = i; j < rank; ++j) { 507 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 508 } 509 strides.push_back(sizeof(Type) * strideFactor); 510 } 511 strides.push_back(sizeof(Type)); 512 return py::buffer_info(data, sizeof(Type), 513 py::format_descriptor<Type>::format(), rank, shape, 514 strides, /*readonly=*/true); 515 } 516 }; // namespace 517 518 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 519 /// (and boolean) values. Supports element access. 520 class PyDenseIntElementsAttribute 521 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 522 PyDenseElementsAttribute> { 523 public: 524 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 525 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 526 using PyConcreteAttribute::PyConcreteAttribute; 527 528 /// Returns the element at the given linear position. Asserts if the index is 529 /// out of range. 530 py::int_ dunderGetItem(intptr_t pos) { 531 if (pos < 0 || pos >= dunderLen()) { 532 throw SetPyError(PyExc_IndexError, 533 "attempt to access out of bounds element"); 534 } 535 536 MlirType type = mlirAttributeGetType(*this); 537 type = mlirShapedTypeGetElementType(type); 538 assert(mlirTypeIsAInteger(type) && 539 "expected integer element type in dense int elements attribute"); 540 // Dispatch element extraction to an appropriate C function based on the 541 // elemental type of the attribute. py::int_ is implicitly constructible 542 // from any C++ integral type and handles bitwidth correctly. 543 // TODO: consider caching the type properties in the constructor to avoid 544 // querying them on each element access. 545 unsigned width = mlirIntegerTypeGetWidth(type); 546 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 547 if (isUnsigned) { 548 if (width == 1) { 549 return mlirDenseElementsAttrGetBoolValue(*this, pos); 550 } 551 if (width == 32) { 552 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 553 } 554 if (width == 64) { 555 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 556 } 557 } else { 558 if (width == 1) { 559 return mlirDenseElementsAttrGetBoolValue(*this, pos); 560 } 561 if (width == 32) { 562 return mlirDenseElementsAttrGetInt32Value(*this, pos); 563 } 564 if (width == 64) { 565 return mlirDenseElementsAttrGetInt64Value(*this, pos); 566 } 567 } 568 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 569 } 570 571 static void bindDerived(ClassTy &c) { 572 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 573 } 574 }; 575 576 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 577 public: 578 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 579 static constexpr const char *pyClassName = "DictAttr"; 580 using PyConcreteAttribute::PyConcreteAttribute; 581 582 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 583 584 static void bindDerived(ClassTy &c) { 585 c.def("__len__", &PyDictAttribute::dunderLen); 586 c.def_static( 587 "get", 588 [](py::dict attributes, DefaultingPyMlirContext context) { 589 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 590 mlirNamedAttributes.reserve(attributes.size()); 591 for (auto &it : attributes) { 592 auto &mlir_attr = it.second.cast<PyAttribute &>(); 593 auto name = it.first.cast<std::string>(); 594 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 595 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 596 toMlirStringRef(name)), 597 mlir_attr)); 598 } 599 MlirAttribute attr = 600 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 601 mlirNamedAttributes.data()); 602 return PyDictAttribute(context->getRef(), attr); 603 }, 604 py::arg("value"), py::arg("context") = py::none(), 605 "Gets an uniqued dict attribute"); 606 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 607 MlirAttribute attr = 608 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 609 if (mlirAttributeIsNull(attr)) { 610 throw SetPyError(PyExc_KeyError, 611 "attempt to access a non-existent attribute"); 612 } 613 return PyAttribute(self.getContext(), attr); 614 }); 615 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 616 if (index < 0 || index >= self.dunderLen()) { 617 throw SetPyError(PyExc_IndexError, 618 "attempt to access out of bounds attribute"); 619 } 620 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 621 return PyNamedAttribute( 622 namedAttr.attribute, 623 std::string(mlirIdentifierStr(namedAttr.name).data)); 624 }); 625 } 626 }; 627 628 /// Refinement of PyDenseElementsAttribute for attributes containing 629 /// floating-point values. Supports element access. 630 class PyDenseFPElementsAttribute 631 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 632 PyDenseElementsAttribute> { 633 public: 634 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 635 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 636 using PyConcreteAttribute::PyConcreteAttribute; 637 638 py::float_ dunderGetItem(intptr_t pos) { 639 if (pos < 0 || pos >= dunderLen()) { 640 throw SetPyError(PyExc_IndexError, 641 "attempt to access out of bounds element"); 642 } 643 644 MlirType type = mlirAttributeGetType(*this); 645 type = mlirShapedTypeGetElementType(type); 646 // Dispatch element extraction to an appropriate C function based on the 647 // elemental type of the attribute. py::float_ is implicitly constructible 648 // from float and double. 649 // TODO: consider caching the type properties in the constructor to avoid 650 // querying them on each element access. 651 if (mlirTypeIsAF32(type)) { 652 return mlirDenseElementsAttrGetFloatValue(*this, pos); 653 } 654 if (mlirTypeIsAF64(type)) { 655 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 656 } 657 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 658 } 659 660 static void bindDerived(ClassTy &c) { 661 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 662 } 663 }; 664 665 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 666 public: 667 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 668 static constexpr const char *pyClassName = "TypeAttr"; 669 using PyConcreteAttribute::PyConcreteAttribute; 670 671 static void bindDerived(ClassTy &c) { 672 c.def_static( 673 "get", 674 [](PyType value, DefaultingPyMlirContext context) { 675 MlirAttribute attr = mlirTypeAttrGet(value.get()); 676 return PyTypeAttribute(context->getRef(), attr); 677 }, 678 py::arg("value"), py::arg("context") = py::none(), 679 "Gets a uniqued Type attribute"); 680 c.def_property_readonly("value", [](PyTypeAttribute &self) { 681 return PyType(self.getContext()->getRef(), 682 mlirTypeAttrGetValue(self.get())); 683 }); 684 } 685 }; 686 687 /// Unit Attribute subclass. Unit attributes don't have values. 688 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 689 public: 690 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 691 static constexpr const char *pyClassName = "UnitAttr"; 692 using PyConcreteAttribute::PyConcreteAttribute; 693 694 static void bindDerived(ClassTy &c) { 695 c.def_static( 696 "get", 697 [](DefaultingPyMlirContext context) { 698 return PyUnitAttribute(context->getRef(), 699 mlirUnitAttrGet(context->get())); 700 }, 701 py::arg("context") = py::none(), "Create a Unit attribute."); 702 } 703 }; 704 705 } // namespace 706 707 void mlir::python::populateIRAttributes(py::module &m) { 708 PyAffineMapAttribute::bind(m); 709 PyArrayAttribute::bind(m); 710 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 711 PyBoolAttribute::bind(m); 712 PyDenseElementsAttribute::bind(m); 713 PyDenseFPElementsAttribute::bind(m); 714 PyDenseIntElementsAttribute::bind(m); 715 PyDictAttribute::bind(m); 716 PyFlatSymbolRefAttribute::bind(m); 717 PyFloatAttribute::bind(m); 718 PyIntegerAttribute::bind(m); 719 PyStringAttribute::bind(m); 720 PyTypeAttribute::bind(m); 721 PyUnitAttribute::bind(m); 722 } 723