1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::SmallVector; 21 using llvm::StringRef; 22 using llvm::Twine; 23 24 namespace { 25 26 static MlirStringRef toMlirStringRef(const std::string &s) { 27 return mlirStringRefCreate(s.data(), s.size()); 28 } 29 30 /// CRTP base classes for Python attributes that subclass Attribute and should 31 /// be castable from it (i.e. via something like StringAttr(attr)). 32 /// By default, attribute class hierarchies are one level deep (i.e. a 33 /// concrete attribute class extends PyAttribute); however, intermediate 34 /// python-visible base classes can be modeled by specifying a BaseTy. 35 template <typename DerivedTy, typename BaseTy = PyAttribute> 36 class PyConcreteAttribute : public BaseTy { 37 public: 38 // Derived classes must define statics for: 39 // IsAFunctionTy isaFunction 40 // const char *pyClassName 41 using ClassTy = py::class_<DerivedTy, BaseTy>; 42 using IsAFunctionTy = bool (*)(MlirAttribute); 43 44 PyConcreteAttribute() = default; 45 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 46 : BaseTy(std::move(contextRef), attr) {} 47 PyConcreteAttribute(PyAttribute &orig) 48 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 49 50 static MlirAttribute castFrom(PyAttribute &orig) { 51 if (!DerivedTy::isaFunction(orig)) { 52 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 53 throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + 54 DerivedTy::pyClassName + 55 " (from " + origRepr + ")"); 56 } 57 return orig; 58 } 59 60 static void bind(py::module &m) { 61 auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); 62 cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>()); 63 DerivedTy::bindDerived(cls); 64 } 65 66 /// Implemented by derived classes to add methods to the Python subclass. 67 static void bindDerived(ClassTy &m) {} 68 }; 69 70 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 71 public: 72 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 73 static constexpr const char *pyClassName = "AffineMapAttr"; 74 using PyConcreteAttribute::PyConcreteAttribute; 75 76 static void bindDerived(ClassTy &c) { 77 c.def_static( 78 "get", 79 [](PyAffineMap &affineMap) { 80 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 81 return PyAffineMapAttribute(affineMap.getContext(), attr); 82 }, 83 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 84 } 85 }; 86 87 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 88 public: 89 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 90 static constexpr const char *pyClassName = "ArrayAttr"; 91 using PyConcreteAttribute::PyConcreteAttribute; 92 93 class PyArrayAttributeIterator { 94 public: 95 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 96 97 PyArrayAttributeIterator &dunderIter() { return *this; } 98 99 PyAttribute dunderNext() { 100 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 101 throw py::stop_iteration(); 102 } 103 return PyAttribute(attr.getContext(), 104 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 105 } 106 107 static void bind(py::module &m) { 108 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 109 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 110 .def("__next__", &PyArrayAttributeIterator::dunderNext); 111 } 112 113 private: 114 PyAttribute attr; 115 int nextIndex = 0; 116 }; 117 118 static void bindDerived(ClassTy &c) { 119 c.def_static( 120 "get", 121 [](py::list attributes, DefaultingPyMlirContext context) { 122 SmallVector<MlirAttribute> mlirAttributes; 123 mlirAttributes.reserve(py::len(attributes)); 124 for (auto attribute : attributes) { 125 try { 126 mlirAttributes.push_back(attribute.cast<PyAttribute>()); 127 } catch (py::cast_error &err) { 128 std::string msg = std::string("Invalid attribute when attempting " 129 "to create an ArrayAttribute (") + 130 err.what() + ")"; 131 throw py::cast_error(msg); 132 } catch (py::reference_cast_error &err) { 133 // This exception seems thrown when the value is "None". 134 std::string msg = 135 std::string("Invalid attribute (None?) when attempting to " 136 "create an ArrayAttribute (") + 137 err.what() + ")"; 138 throw py::cast_error(msg); 139 } 140 } 141 MlirAttribute attr = mlirArrayAttrGet( 142 context->get(), mlirAttributes.size(), mlirAttributes.data()); 143 return PyArrayAttribute(context->getRef(), attr); 144 }, 145 py::arg("attributes"), py::arg("context") = py::none(), 146 "Gets a uniqued Array attribute"); 147 c.def("__getitem__", 148 [](PyArrayAttribute &arr, intptr_t i) { 149 if (i >= mlirArrayAttrGetNumElements(arr)) 150 throw py::index_error("ArrayAttribute index out of range"); 151 return PyAttribute(arr.getContext(), 152 mlirArrayAttrGetElement(arr, i)); 153 }) 154 .def("__len__", 155 [](const PyArrayAttribute &arr) { 156 return mlirArrayAttrGetNumElements(arr); 157 }) 158 .def("__iter__", [](const PyArrayAttribute &arr) { 159 return PyArrayAttributeIterator(arr); 160 }); 161 } 162 }; 163 164 /// Float Point Attribute subclass - FloatAttr. 165 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 166 public: 167 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 168 static constexpr const char *pyClassName = "FloatAttr"; 169 using PyConcreteAttribute::PyConcreteAttribute; 170 171 static void bindDerived(ClassTy &c) { 172 c.def_static( 173 "get", 174 [](PyType &type, double value, DefaultingPyLocation loc) { 175 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 176 // TODO: Rework error reporting once diagnostic engine is exposed 177 // in C API. 178 if (mlirAttributeIsNull(attr)) { 179 throw SetPyError(PyExc_ValueError, 180 Twine("invalid '") + 181 py::repr(py::cast(type)).cast<std::string>() + 182 "' and expected floating point type."); 183 } 184 return PyFloatAttribute(type.getContext(), attr); 185 }, 186 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 187 "Gets an uniqued float point attribute associated to a type"); 188 c.def_static( 189 "get_f32", 190 [](double value, DefaultingPyMlirContext context) { 191 MlirAttribute attr = mlirFloatAttrDoubleGet( 192 context->get(), mlirF32TypeGet(context->get()), value); 193 return PyFloatAttribute(context->getRef(), attr); 194 }, 195 py::arg("value"), py::arg("context") = py::none(), 196 "Gets an uniqued float point attribute associated to a f32 type"); 197 c.def_static( 198 "get_f64", 199 [](double value, DefaultingPyMlirContext context) { 200 MlirAttribute attr = mlirFloatAttrDoubleGet( 201 context->get(), mlirF64TypeGet(context->get()), value); 202 return PyFloatAttribute(context->getRef(), attr); 203 }, 204 py::arg("value"), py::arg("context") = py::none(), 205 "Gets an uniqued float point attribute associated to a f64 type"); 206 c.def_property_readonly( 207 "value", 208 [](PyFloatAttribute &self) { 209 return mlirFloatAttrGetValueDouble(self); 210 }, 211 "Returns the value of the float point attribute"); 212 } 213 }; 214 215 /// Integer Attribute subclass - IntegerAttr. 216 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 217 public: 218 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 219 static constexpr const char *pyClassName = "IntegerAttr"; 220 using PyConcreteAttribute::PyConcreteAttribute; 221 222 static void bindDerived(ClassTy &c) { 223 c.def_static( 224 "get", 225 [](PyType &type, int64_t value) { 226 MlirAttribute attr = mlirIntegerAttrGet(type, value); 227 return PyIntegerAttribute(type.getContext(), attr); 228 }, 229 py::arg("type"), py::arg("value"), 230 "Gets an uniqued integer attribute associated to a type"); 231 c.def_property_readonly( 232 "value", 233 [](PyIntegerAttribute &self) { 234 return mlirIntegerAttrGetValueInt(self); 235 }, 236 "Returns the value of the integer attribute"); 237 } 238 }; 239 240 /// Bool Attribute subclass - BoolAttr. 241 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 242 public: 243 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 244 static constexpr const char *pyClassName = "BoolAttr"; 245 using PyConcreteAttribute::PyConcreteAttribute; 246 247 static void bindDerived(ClassTy &c) { 248 c.def_static( 249 "get", 250 [](bool value, DefaultingPyMlirContext context) { 251 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 252 return PyBoolAttribute(context->getRef(), attr); 253 }, 254 py::arg("value"), py::arg("context") = py::none(), 255 "Gets an uniqued bool attribute"); 256 c.def_property_readonly( 257 "value", 258 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 259 "Returns the value of the bool attribute"); 260 } 261 }; 262 263 class PyFlatSymbolRefAttribute 264 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 265 public: 266 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 267 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 268 using PyConcreteAttribute::PyConcreteAttribute; 269 270 static void bindDerived(ClassTy &c) { 271 c.def_static( 272 "get", 273 [](std::string value, DefaultingPyMlirContext context) { 274 MlirAttribute attr = 275 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 276 return PyFlatSymbolRefAttribute(context->getRef(), attr); 277 }, 278 py::arg("value"), py::arg("context") = py::none(), 279 "Gets a uniqued FlatSymbolRef attribute"); 280 c.def_property_readonly( 281 "value", 282 [](PyFlatSymbolRefAttribute &self) { 283 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 284 return py::str(stringRef.data, stringRef.length); 285 }, 286 "Returns the value of the FlatSymbolRef attribute as a string"); 287 } 288 }; 289 290 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 291 public: 292 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 293 static constexpr const char *pyClassName = "StringAttr"; 294 using PyConcreteAttribute::PyConcreteAttribute; 295 296 static void bindDerived(ClassTy &c) { 297 c.def_static( 298 "get", 299 [](std::string value, DefaultingPyMlirContext context) { 300 MlirAttribute attr = 301 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 302 return PyStringAttribute(context->getRef(), attr); 303 }, 304 py::arg("value"), py::arg("context") = py::none(), 305 "Gets a uniqued string attribute"); 306 c.def_static( 307 "get_typed", 308 [](PyType &type, std::string value) { 309 MlirAttribute attr = 310 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 311 return PyStringAttribute(type.getContext(), attr); 312 }, 313 314 "Gets a uniqued string attribute associated to a type"); 315 c.def_property_readonly( 316 "value", 317 [](PyStringAttribute &self) { 318 MlirStringRef stringRef = mlirStringAttrGetValue(self); 319 return py::str(stringRef.data, stringRef.length); 320 }, 321 "Returns the value of the string attribute"); 322 } 323 }; 324 325 // TODO: Support construction of bool elements. 326 // TODO: Support construction of string elements. 327 class PyDenseElementsAttribute 328 : public PyConcreteAttribute<PyDenseElementsAttribute> { 329 public: 330 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 331 static constexpr const char *pyClassName = "DenseElementsAttr"; 332 using PyConcreteAttribute::PyConcreteAttribute; 333 334 static PyDenseElementsAttribute 335 getFromBuffer(py::buffer array, bool signless, 336 DefaultingPyMlirContext contextWrapper) { 337 // Request a contiguous view. In exotic cases, this will cause a copy. 338 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 339 Py_buffer *view = new Py_buffer(); 340 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 341 delete view; 342 throw py::error_already_set(); 343 } 344 py::buffer_info arrayInfo(view); 345 346 MlirContext context = contextWrapper->get(); 347 // Switch on the types that can be bulk loaded between the Python and 348 // MLIR-C APIs. 349 // See: https://docs.python.org/3/library/struct.html#format-characters 350 if (arrayInfo.format == "f") { 351 // f32 352 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 353 return PyDenseElementsAttribute( 354 contextWrapper->getRef(), 355 bulkLoad(context, mlirDenseElementsAttrFloatGet, 356 mlirF32TypeGet(context), arrayInfo)); 357 } else if (arrayInfo.format == "d") { 358 // f64 359 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 360 return PyDenseElementsAttribute( 361 contextWrapper->getRef(), 362 bulkLoad(context, mlirDenseElementsAttrDoubleGet, 363 mlirF64TypeGet(context), arrayInfo)); 364 } else if (isSignedIntegerFormat(arrayInfo.format)) { 365 if (arrayInfo.itemsize == 4) { 366 // i32 367 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 368 : mlirIntegerTypeSignedGet(context, 32); 369 return PyDenseElementsAttribute(contextWrapper->getRef(), 370 bulkLoad(context, 371 mlirDenseElementsAttrInt32Get, 372 elementType, arrayInfo)); 373 } else if (arrayInfo.itemsize == 8) { 374 // i64 375 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 376 : mlirIntegerTypeSignedGet(context, 64); 377 return PyDenseElementsAttribute(contextWrapper->getRef(), 378 bulkLoad(context, 379 mlirDenseElementsAttrInt64Get, 380 elementType, arrayInfo)); 381 } 382 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 383 if (arrayInfo.itemsize == 4) { 384 // unsigned i32 385 MlirType elementType = signless 386 ? mlirIntegerTypeGet(context, 32) 387 : mlirIntegerTypeUnsignedGet(context, 32); 388 return PyDenseElementsAttribute(contextWrapper->getRef(), 389 bulkLoad(context, 390 mlirDenseElementsAttrUInt32Get, 391 elementType, arrayInfo)); 392 } else if (arrayInfo.itemsize == 8) { 393 // unsigned i64 394 MlirType elementType = signless 395 ? mlirIntegerTypeGet(context, 64) 396 : mlirIntegerTypeUnsignedGet(context, 64); 397 return PyDenseElementsAttribute(contextWrapper->getRef(), 398 bulkLoad(context, 399 mlirDenseElementsAttrUInt64Get, 400 elementType, arrayInfo)); 401 } 402 } 403 404 // TODO: Fall back to string-based get. 405 std::string message = "unimplemented array format conversion from format: "; 406 message.append(arrayInfo.format); 407 throw SetPyError(PyExc_ValueError, message); 408 } 409 410 static PyDenseElementsAttribute getSplat(PyType shapedType, 411 PyAttribute &elementAttr) { 412 auto contextWrapper = 413 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 414 if (!mlirAttributeIsAInteger(elementAttr) && 415 !mlirAttributeIsAFloat(elementAttr)) { 416 std::string message = "Illegal element type for DenseElementsAttr: "; 417 message.append(py::repr(py::cast(elementAttr))); 418 throw SetPyError(PyExc_ValueError, message); 419 } 420 if (!mlirTypeIsAShaped(shapedType) || 421 !mlirShapedTypeHasStaticShape(shapedType)) { 422 std::string message = 423 "Expected a static ShapedType for the shaped_type parameter: "; 424 message.append(py::repr(py::cast(shapedType))); 425 throw SetPyError(PyExc_ValueError, message); 426 } 427 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 428 MlirType attrType = mlirAttributeGetType(elementAttr); 429 if (!mlirTypeEqual(shapedElementType, attrType)) { 430 std::string message = 431 "Shaped element type and attribute type must be equal: shaped="; 432 message.append(py::repr(py::cast(shapedType))); 433 message.append(", element="); 434 message.append(py::repr(py::cast(elementAttr))); 435 throw SetPyError(PyExc_ValueError, message); 436 } 437 438 MlirAttribute elements = 439 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 440 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 441 } 442 443 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 444 445 py::buffer_info accessBuffer() { 446 MlirType shapedType = mlirAttributeGetType(*this); 447 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 448 449 if (mlirTypeIsAF32(elementType)) { 450 // f32 451 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 452 } else if (mlirTypeIsAF64(elementType)) { 453 // f64 454 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 455 } else if (mlirTypeIsAInteger(elementType) && 456 mlirIntegerTypeGetWidth(elementType) == 32) { 457 if (mlirIntegerTypeIsSignless(elementType) || 458 mlirIntegerTypeIsSigned(elementType)) { 459 // i32 460 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 461 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 462 // unsigned i32 463 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 464 } 465 } else if (mlirTypeIsAInteger(elementType) && 466 mlirIntegerTypeGetWidth(elementType) == 64) { 467 if (mlirIntegerTypeIsSignless(elementType) || 468 mlirIntegerTypeIsSigned(elementType)) { 469 // i64 470 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 471 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 472 // unsigned i64 473 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 474 } 475 } 476 477 std::string message = "unimplemented array format."; 478 throw SetPyError(PyExc_ValueError, message); 479 } 480 481 static void bindDerived(ClassTy &c) { 482 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 483 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 484 py::arg("array"), py::arg("signless") = true, 485 py::arg("context") = py::none(), 486 "Gets from a buffer or ndarray") 487 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 488 py::arg("shaped_type"), py::arg("element_attr"), 489 "Gets a DenseElementsAttr where all values are the same") 490 .def_property_readonly("is_splat", 491 [](PyDenseElementsAttribute &self) -> bool { 492 return mlirDenseElementsAttrIsSplat(self); 493 }) 494 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 495 } 496 497 private: 498 template <typename ElementTy> 499 static MlirAttribute 500 bulkLoad(MlirContext context, 501 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 502 MlirType mlirElementType, py::buffer_info &arrayInfo) { 503 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 504 arrayInfo.shape.begin() + arrayInfo.ndim); 505 auto shapedType = 506 mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); 507 intptr_t numElements = arrayInfo.size; 508 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 509 return ctor(shapedType, numElements, contents); 510 } 511 512 static bool isUnsignedIntegerFormat(const std::string &format) { 513 if (format.empty()) 514 return false; 515 char code = format[0]; 516 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 517 code == 'Q'; 518 } 519 520 static bool isSignedIntegerFormat(const std::string &format) { 521 if (format.empty()) 522 return false; 523 char code = format[0]; 524 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 525 code == 'q'; 526 } 527 528 template <typename Type> 529 py::buffer_info bufferInfo(MlirType shapedType, 530 Type (*value)(MlirAttribute, intptr_t)) { 531 intptr_t rank = mlirShapedTypeGetRank(shapedType); 532 // Prepare the data for the buffer_info. 533 // Buffer is configured for read-only access below. 534 Type *data = static_cast<Type *>( 535 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 536 // Prepare the shape for the buffer_info. 537 SmallVector<intptr_t, 4> shape; 538 for (intptr_t i = 0; i < rank; ++i) 539 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 540 // Prepare the strides for the buffer_info. 541 SmallVector<intptr_t, 4> strides; 542 intptr_t strideFactor = 1; 543 for (intptr_t i = 1; i < rank; ++i) { 544 strideFactor = 1; 545 for (intptr_t j = i; j < rank; ++j) { 546 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 547 } 548 strides.push_back(sizeof(Type) * strideFactor); 549 } 550 strides.push_back(sizeof(Type)); 551 return py::buffer_info(data, sizeof(Type), 552 py::format_descriptor<Type>::format(), rank, shape, 553 strides, /*readonly=*/true); 554 } 555 }; // namespace 556 557 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 558 /// (and boolean) values. Supports element access. 559 class PyDenseIntElementsAttribute 560 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 561 PyDenseElementsAttribute> { 562 public: 563 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 564 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 565 using PyConcreteAttribute::PyConcreteAttribute; 566 567 /// Returns the element at the given linear position. Asserts if the index is 568 /// out of range. 569 py::int_ dunderGetItem(intptr_t pos) { 570 if (pos < 0 || pos >= dunderLen()) { 571 throw SetPyError(PyExc_IndexError, 572 "attempt to access out of bounds element"); 573 } 574 575 MlirType type = mlirAttributeGetType(*this); 576 type = mlirShapedTypeGetElementType(type); 577 assert(mlirTypeIsAInteger(type) && 578 "expected integer element type in dense int elements attribute"); 579 // Dispatch element extraction to an appropriate C function based on the 580 // elemental type of the attribute. py::int_ is implicitly constructible 581 // from any C++ integral type and handles bitwidth correctly. 582 // TODO: consider caching the type properties in the constructor to avoid 583 // querying them on each element access. 584 unsigned width = mlirIntegerTypeGetWidth(type); 585 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 586 if (isUnsigned) { 587 if (width == 1) { 588 return mlirDenseElementsAttrGetBoolValue(*this, pos); 589 } 590 if (width == 32) { 591 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 592 } 593 if (width == 64) { 594 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 595 } 596 } else { 597 if (width == 1) { 598 return mlirDenseElementsAttrGetBoolValue(*this, pos); 599 } 600 if (width == 32) { 601 return mlirDenseElementsAttrGetInt32Value(*this, pos); 602 } 603 if (width == 64) { 604 return mlirDenseElementsAttrGetInt64Value(*this, pos); 605 } 606 } 607 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 608 } 609 610 static void bindDerived(ClassTy &c) { 611 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 612 } 613 }; 614 615 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 616 public: 617 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 618 static constexpr const char *pyClassName = "DictAttr"; 619 using PyConcreteAttribute::PyConcreteAttribute; 620 621 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 622 623 static void bindDerived(ClassTy &c) { 624 c.def("__len__", &PyDictAttribute::dunderLen); 625 c.def_static( 626 "get", 627 [](py::dict attributes, DefaultingPyMlirContext context) { 628 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 629 mlirNamedAttributes.reserve(attributes.size()); 630 for (auto &it : attributes) { 631 auto &mlir_attr = it.second.cast<PyAttribute &>(); 632 auto name = it.first.cast<std::string>(); 633 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 634 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 635 toMlirStringRef(name)), 636 mlir_attr)); 637 } 638 MlirAttribute attr = 639 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 640 mlirNamedAttributes.data()); 641 return PyDictAttribute(context->getRef(), attr); 642 }, 643 py::arg("value"), py::arg("context") = py::none(), 644 "Gets an uniqued dict attribute"); 645 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 646 MlirAttribute attr = 647 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 648 if (mlirAttributeIsNull(attr)) { 649 throw SetPyError(PyExc_KeyError, 650 "attempt to access a non-existent attribute"); 651 } 652 return PyAttribute(self.getContext(), attr); 653 }); 654 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 655 if (index < 0 || index >= self.dunderLen()) { 656 throw SetPyError(PyExc_IndexError, 657 "attempt to access out of bounds attribute"); 658 } 659 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 660 return PyNamedAttribute( 661 namedAttr.attribute, 662 std::string(mlirIdentifierStr(namedAttr.name).data)); 663 }); 664 } 665 }; 666 667 /// Refinement of PyDenseElementsAttribute for attributes containing 668 /// floating-point values. Supports element access. 669 class PyDenseFPElementsAttribute 670 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 671 PyDenseElementsAttribute> { 672 public: 673 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 674 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 675 using PyConcreteAttribute::PyConcreteAttribute; 676 677 py::float_ dunderGetItem(intptr_t pos) { 678 if (pos < 0 || pos >= dunderLen()) { 679 throw SetPyError(PyExc_IndexError, 680 "attempt to access out of bounds element"); 681 } 682 683 MlirType type = mlirAttributeGetType(*this); 684 type = mlirShapedTypeGetElementType(type); 685 // Dispatch element extraction to an appropriate C function based on the 686 // elemental type of the attribute. py::float_ is implicitly constructible 687 // from float and double. 688 // TODO: consider caching the type properties in the constructor to avoid 689 // querying them on each element access. 690 if (mlirTypeIsAF32(type)) { 691 return mlirDenseElementsAttrGetFloatValue(*this, pos); 692 } 693 if (mlirTypeIsAF64(type)) { 694 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 695 } 696 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 697 } 698 699 static void bindDerived(ClassTy &c) { 700 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 701 } 702 }; 703 704 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 705 public: 706 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 707 static constexpr const char *pyClassName = "TypeAttr"; 708 using PyConcreteAttribute::PyConcreteAttribute; 709 710 static void bindDerived(ClassTy &c) { 711 c.def_static( 712 "get", 713 [](PyType value, DefaultingPyMlirContext context) { 714 MlirAttribute attr = mlirTypeAttrGet(value.get()); 715 return PyTypeAttribute(context->getRef(), attr); 716 }, 717 py::arg("value"), py::arg("context") = py::none(), 718 "Gets a uniqued Type attribute"); 719 c.def_property_readonly("value", [](PyTypeAttribute &self) { 720 return PyType(self.getContext()->getRef(), 721 mlirTypeAttrGetValue(self.get())); 722 }); 723 } 724 }; 725 726 /// Unit Attribute subclass. Unit attributes don't have values. 727 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 728 public: 729 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 730 static constexpr const char *pyClassName = "UnitAttr"; 731 using PyConcreteAttribute::PyConcreteAttribute; 732 733 static void bindDerived(ClassTy &c) { 734 c.def_static( 735 "get", 736 [](DefaultingPyMlirContext context) { 737 return PyUnitAttribute(context->getRef(), 738 mlirUnitAttrGet(context->get())); 739 }, 740 py::arg("context") = py::none(), "Create a Unit attribute."); 741 } 742 }; 743 744 } // namespace 745 746 void mlir::python::populateIRAttributes(py::module &m) { 747 PyAffineMapAttribute::bind(m); 748 PyArrayAttribute::bind(m); 749 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 750 PyBoolAttribute::bind(m); 751 PyDenseElementsAttribute::bind(m); 752 PyDenseFPElementsAttribute::bind(m); 753 PyDenseIntElementsAttribute::bind(m); 754 PyDictAttribute::bind(m); 755 PyFlatSymbolRefAttribute::bind(m); 756 PyFloatAttribute::bind(m); 757 PyIntegerAttribute::bind(m); 758 PyStringAttribute::bind(m); 759 PyTypeAttribute::bind(m); 760 PyUnitAttribute::bind(m); 761 } 762