1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "IRModule.h" 12 13 #include "PybindUtils.h" 14 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 18 namespace py = pybind11; 19 using namespace mlir; 20 using namespace mlir::python; 21 22 using llvm::Optional; 23 using llvm::SmallVector; 24 using llvm::Twine; 25 26 //------------------------------------------------------------------------------ 27 // Docstrings (trivial, non-duplicated docstrings are included inline). 28 //------------------------------------------------------------------------------ 29 30 static const char kDenseElementsAttrGetDocstring[] = 31 R"(Gets a DenseElementsAttr from a Python buffer or array. 32 33 When `type` is not provided, then some limited type inferencing is done based 34 on the buffer format. Support presently exists for 8/16/32/64 signed and 35 unsigned integers and float16/float32/float64. DenseElementsAttrs of these 36 types can also be converted back to a corresponding buffer. 37 38 For conversions outside of these types, a `type=` must be explicitly provided 39 and the buffer contents must be bit-castable to the MLIR internal 40 representation: 41 42 * Integer types (except for i1): the buffer must be byte aligned to the 43 next byte boundary. 44 * Floating point types: Must be bit-castable to the given floating point 45 size. 46 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 47 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 48 this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 49 50 If a single element buffer is passed (or for i1, a single byte with value 0 51 or 255), then a splat will be created. 52 53 Args: 54 array: The array or buffer to convert. 55 signless: If inferring an appropriate MLIR type, use signless types for 56 integers (defaults True). 57 type: Skips inference of the MLIR element type and uses this instead. The 58 storage size must be consistent with the actual contents of the buffer. 59 shape: Overrides the shape of the buffer when constructing the MLIR 60 shaped type. This is needed when the physical and logical shape differ (as 61 for i1). 62 context: Explicit context, if not from context manager. 63 64 Returns: 65 DenseElementsAttr on success. 66 67 Raises: 68 ValueError: If the type of the buffer or array cannot be matched to an MLIR 69 type or if the buffer does not meet expectations. 70 )"; 71 72 namespace { 73 74 static MlirStringRef toMlirStringRef(const std::string &s) { 75 return mlirStringRefCreate(s.data(), s.size()); 76 } 77 78 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79 public: 80 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81 static constexpr const char *pyClassName = "AffineMapAttr"; 82 using PyConcreteAttribute::PyConcreteAttribute; 83 84 static void bindDerived(ClassTy &c) { 85 c.def_static( 86 "get", 87 [](PyAffineMap &affineMap) { 88 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 89 return PyAffineMapAttribute(affineMap.getContext(), attr); 90 }, 91 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 92 } 93 }; 94 95 template <typename T> 96 static T pyTryCast(py::handle object) { 97 try { 98 return object.cast<T>(); 99 } catch (py::cast_error &err) { 100 std::string msg = 101 std::string( 102 "Invalid attribute when attempting to create an ArrayAttribute (") + 103 err.what() + ")"; 104 throw py::cast_error(msg); 105 } catch (py::reference_cast_error &err) { 106 std::string msg = std::string("Invalid attribute (None?) when attempting " 107 "to create an ArrayAttribute (") + 108 err.what() + ")"; 109 throw py::cast_error(msg); 110 } 111 } 112 113 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 114 public: 115 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 116 static constexpr const char *pyClassName = "ArrayAttr"; 117 using PyConcreteAttribute::PyConcreteAttribute; 118 119 class PyArrayAttributeIterator { 120 public: 121 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 122 123 PyArrayAttributeIterator &dunderIter() { return *this; } 124 125 PyAttribute dunderNext() { 126 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 127 throw py::stop_iteration(); 128 } 129 return PyAttribute(attr.getContext(), 130 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 131 } 132 133 static void bind(py::module &m) { 134 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 135 py::module_local()) 136 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 137 .def("__next__", &PyArrayAttributeIterator::dunderNext); 138 } 139 140 private: 141 PyAttribute attr; 142 int nextIndex = 0; 143 }; 144 145 PyAttribute getItem(intptr_t i) { 146 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 147 } 148 149 static void bindDerived(ClassTy &c) { 150 c.def_static( 151 "get", 152 [](py::list attributes, DefaultingPyMlirContext context) { 153 SmallVector<MlirAttribute> mlirAttributes; 154 mlirAttributes.reserve(py::len(attributes)); 155 for (auto attribute : attributes) { 156 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 157 } 158 MlirAttribute attr = mlirArrayAttrGet( 159 context->get(), mlirAttributes.size(), mlirAttributes.data()); 160 return PyArrayAttribute(context->getRef(), attr); 161 }, 162 py::arg("attributes"), py::arg("context") = py::none(), 163 "Gets a uniqued Array attribute"); 164 c.def("__getitem__", 165 [](PyArrayAttribute &arr, intptr_t i) { 166 if (i >= mlirArrayAttrGetNumElements(arr)) 167 throw py::index_error("ArrayAttribute index out of range"); 168 return arr.getItem(i); 169 }) 170 .def("__len__", 171 [](const PyArrayAttribute &arr) { 172 return mlirArrayAttrGetNumElements(arr); 173 }) 174 .def("__iter__", [](const PyArrayAttribute &arr) { 175 return PyArrayAttributeIterator(arr); 176 }); 177 c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 178 std::vector<MlirAttribute> attributes; 179 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 180 attributes.reserve(numOldElements + py::len(extras)); 181 for (intptr_t i = 0; i < numOldElements; ++i) 182 attributes.push_back(arr.getItem(i)); 183 for (py::handle attr : extras) 184 attributes.push_back(pyTryCast<PyAttribute>(attr)); 185 MlirAttribute arrayAttr = mlirArrayAttrGet( 186 arr.getContext()->get(), attributes.size(), attributes.data()); 187 return PyArrayAttribute(arr.getContext(), arrayAttr); 188 }); 189 } 190 }; 191 192 /// Float Point Attribute subclass - FloatAttr. 193 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 194 public: 195 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 196 static constexpr const char *pyClassName = "FloatAttr"; 197 using PyConcreteAttribute::PyConcreteAttribute; 198 199 static void bindDerived(ClassTy &c) { 200 c.def_static( 201 "get", 202 [](PyType &type, double value, DefaultingPyLocation loc) { 203 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 204 // TODO: Rework error reporting once diagnostic engine is exposed 205 // in C API. 206 if (mlirAttributeIsNull(attr)) { 207 throw SetPyError(PyExc_ValueError, 208 Twine("invalid '") + 209 py::repr(py::cast(type)).cast<std::string>() + 210 "' and expected floating point type."); 211 } 212 return PyFloatAttribute(type.getContext(), attr); 213 }, 214 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 215 "Gets an uniqued float point attribute associated to a type"); 216 c.def_static( 217 "get_f32", 218 [](double value, DefaultingPyMlirContext context) { 219 MlirAttribute attr = mlirFloatAttrDoubleGet( 220 context->get(), mlirF32TypeGet(context->get()), value); 221 return PyFloatAttribute(context->getRef(), attr); 222 }, 223 py::arg("value"), py::arg("context") = py::none(), 224 "Gets an uniqued float point attribute associated to a f32 type"); 225 c.def_static( 226 "get_f64", 227 [](double value, DefaultingPyMlirContext context) { 228 MlirAttribute attr = mlirFloatAttrDoubleGet( 229 context->get(), mlirF64TypeGet(context->get()), value); 230 return PyFloatAttribute(context->getRef(), attr); 231 }, 232 py::arg("value"), py::arg("context") = py::none(), 233 "Gets an uniqued float point attribute associated to a f64 type"); 234 c.def_property_readonly( 235 "value", 236 [](PyFloatAttribute &self) { 237 return mlirFloatAttrGetValueDouble(self); 238 }, 239 "Returns the value of the float point attribute"); 240 } 241 }; 242 243 /// Integer Attribute subclass - IntegerAttr. 244 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 245 public: 246 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 247 static constexpr const char *pyClassName = "IntegerAttr"; 248 using PyConcreteAttribute::PyConcreteAttribute; 249 250 static void bindDerived(ClassTy &c) { 251 c.def_static( 252 "get", 253 [](PyType &type, int64_t value) { 254 MlirAttribute attr = mlirIntegerAttrGet(type, value); 255 return PyIntegerAttribute(type.getContext(), attr); 256 }, 257 py::arg("type"), py::arg("value"), 258 "Gets an uniqued integer attribute associated to a type"); 259 c.def_property_readonly( 260 "value", 261 [](PyIntegerAttribute &self) -> py::int_ { 262 MlirType type = mlirAttributeGetType(self); 263 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 264 return mlirIntegerAttrGetValueInt(self); 265 if (mlirIntegerTypeIsSigned(type)) 266 return mlirIntegerAttrGetValueSInt(self); 267 return mlirIntegerAttrGetValueUInt(self); 268 }, 269 "Returns the value of the integer attribute"); 270 } 271 }; 272 273 /// Bool Attribute subclass - BoolAttr. 274 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 275 public: 276 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 277 static constexpr const char *pyClassName = "BoolAttr"; 278 using PyConcreteAttribute::PyConcreteAttribute; 279 280 static void bindDerived(ClassTy &c) { 281 c.def_static( 282 "get", 283 [](bool value, DefaultingPyMlirContext context) { 284 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 285 return PyBoolAttribute(context->getRef(), attr); 286 }, 287 py::arg("value"), py::arg("context") = py::none(), 288 "Gets an uniqued bool attribute"); 289 c.def_property_readonly( 290 "value", 291 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 292 "Returns the value of the bool attribute"); 293 } 294 }; 295 296 class PyFlatSymbolRefAttribute 297 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 298 public: 299 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 300 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 301 using PyConcreteAttribute::PyConcreteAttribute; 302 303 static void bindDerived(ClassTy &c) { 304 c.def_static( 305 "get", 306 [](std::string value, DefaultingPyMlirContext context) { 307 MlirAttribute attr = 308 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 309 return PyFlatSymbolRefAttribute(context->getRef(), attr); 310 }, 311 py::arg("value"), py::arg("context") = py::none(), 312 "Gets a uniqued FlatSymbolRef attribute"); 313 c.def_property_readonly( 314 "value", 315 [](PyFlatSymbolRefAttribute &self) { 316 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 317 return py::str(stringRef.data, stringRef.length); 318 }, 319 "Returns the value of the FlatSymbolRef attribute as a string"); 320 } 321 }; 322 323 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 324 public: 325 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 326 static constexpr const char *pyClassName = "StringAttr"; 327 using PyConcreteAttribute::PyConcreteAttribute; 328 329 static void bindDerived(ClassTy &c) { 330 c.def_static( 331 "get", 332 [](std::string value, DefaultingPyMlirContext context) { 333 MlirAttribute attr = 334 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 335 return PyStringAttribute(context->getRef(), attr); 336 }, 337 py::arg("value"), py::arg("context") = py::none(), 338 "Gets a uniqued string attribute"); 339 c.def_static( 340 "get_typed", 341 [](PyType &type, std::string value) { 342 MlirAttribute attr = 343 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 344 return PyStringAttribute(type.getContext(), attr); 345 }, 346 py::arg("type"), py::arg("value"), 347 "Gets a uniqued string attribute associated to a type"); 348 c.def_property_readonly( 349 "value", 350 [](PyStringAttribute &self) { 351 MlirStringRef stringRef = mlirStringAttrGetValue(self); 352 return py::str(stringRef.data, stringRef.length); 353 }, 354 "Returns the value of the string attribute"); 355 } 356 }; 357 358 // TODO: Support construction of string elements. 359 class PyDenseElementsAttribute 360 : public PyConcreteAttribute<PyDenseElementsAttribute> { 361 public: 362 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 363 static constexpr const char *pyClassName = "DenseElementsAttr"; 364 using PyConcreteAttribute::PyConcreteAttribute; 365 366 static PyDenseElementsAttribute 367 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 368 Optional<std::vector<int64_t>> explicitShape, 369 DefaultingPyMlirContext contextWrapper) { 370 // Request a contiguous view. In exotic cases, this will cause a copy. 371 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 372 Py_buffer *view = new Py_buffer(); 373 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 374 delete view; 375 throw py::error_already_set(); 376 } 377 py::buffer_info arrayInfo(view); 378 SmallVector<int64_t> shape; 379 if (explicitShape) { 380 shape.append(explicitShape->begin(), explicitShape->end()); 381 } else { 382 shape.append(arrayInfo.shape.begin(), 383 arrayInfo.shape.begin() + arrayInfo.ndim); 384 } 385 386 MlirAttribute encodingAttr = mlirAttributeGetNull(); 387 MlirContext context = contextWrapper->get(); 388 389 // Detect format codes that are suitable for bulk loading. This includes 390 // all byte aligned integer and floating point types up to 8 bytes. 391 // Notably, this excludes, bool (which needs to be bit-packed) and 392 // other exotics which do not have a direct representation in the buffer 393 // protocol (i.e. complex, etc). 394 Optional<MlirType> bulkLoadElementType; 395 if (explicitType) { 396 bulkLoadElementType = *explicitType; 397 } else if (arrayInfo.format == "f") { 398 // f32 399 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 400 bulkLoadElementType = mlirF32TypeGet(context); 401 } else if (arrayInfo.format == "d") { 402 // f64 403 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 404 bulkLoadElementType = mlirF64TypeGet(context); 405 } else if (arrayInfo.format == "e") { 406 // f16 407 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 408 bulkLoadElementType = mlirF16TypeGet(context); 409 } else if (isSignedIntegerFormat(arrayInfo.format)) { 410 if (arrayInfo.itemsize == 4) { 411 // i32 412 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 413 : mlirIntegerTypeSignedGet(context, 32); 414 } else if (arrayInfo.itemsize == 8) { 415 // i64 416 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 417 : mlirIntegerTypeSignedGet(context, 64); 418 } else if (arrayInfo.itemsize == 1) { 419 // i8 420 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 421 : mlirIntegerTypeSignedGet(context, 8); 422 } else if (arrayInfo.itemsize == 2) { 423 // i16 424 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 425 : mlirIntegerTypeSignedGet(context, 16); 426 } 427 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 428 if (arrayInfo.itemsize == 4) { 429 // unsigned i32 430 bulkLoadElementType = signless 431 ? mlirIntegerTypeGet(context, 32) 432 : mlirIntegerTypeUnsignedGet(context, 32); 433 } else if (arrayInfo.itemsize == 8) { 434 // unsigned i64 435 bulkLoadElementType = signless 436 ? mlirIntegerTypeGet(context, 64) 437 : mlirIntegerTypeUnsignedGet(context, 64); 438 } else if (arrayInfo.itemsize == 1) { 439 // i8 440 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 441 : mlirIntegerTypeUnsignedGet(context, 8); 442 } else if (arrayInfo.itemsize == 2) { 443 // i16 444 bulkLoadElementType = signless 445 ? mlirIntegerTypeGet(context, 16) 446 : mlirIntegerTypeUnsignedGet(context, 16); 447 } 448 } 449 if (bulkLoadElementType) { 450 auto shapedType = mlirRankedTensorTypeGet( 451 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 452 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 453 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 454 shapedType, rawBufferSize, arrayInfo.ptr); 455 if (mlirAttributeIsNull(attr)) { 456 throw std::invalid_argument( 457 "DenseElementsAttr could not be constructed from the given buffer. " 458 "This may mean that the Python buffer layout does not match that " 459 "MLIR expected layout and is a bug."); 460 } 461 return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 462 } 463 464 throw std::invalid_argument( 465 std::string("unimplemented array format conversion from format: ") + 466 arrayInfo.format); 467 } 468 469 static PyDenseElementsAttribute getSplat(const PyType &shapedType, 470 PyAttribute &elementAttr) { 471 auto contextWrapper = 472 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 473 if (!mlirAttributeIsAInteger(elementAttr) && 474 !mlirAttributeIsAFloat(elementAttr)) { 475 std::string message = "Illegal element type for DenseElementsAttr: "; 476 message.append(py::repr(py::cast(elementAttr))); 477 throw SetPyError(PyExc_ValueError, message); 478 } 479 if (!mlirTypeIsAShaped(shapedType) || 480 !mlirShapedTypeHasStaticShape(shapedType)) { 481 std::string message = 482 "Expected a static ShapedType for the shaped_type parameter: "; 483 message.append(py::repr(py::cast(shapedType))); 484 throw SetPyError(PyExc_ValueError, message); 485 } 486 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 487 MlirType attrType = mlirAttributeGetType(elementAttr); 488 if (!mlirTypeEqual(shapedElementType, attrType)) { 489 std::string message = 490 "Shaped element type and attribute type must be equal: shaped="; 491 message.append(py::repr(py::cast(shapedType))); 492 message.append(", element="); 493 message.append(py::repr(py::cast(elementAttr))); 494 throw SetPyError(PyExc_ValueError, message); 495 } 496 497 MlirAttribute elements = 498 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 499 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 500 } 501 502 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 503 504 py::buffer_info accessBuffer() { 505 if (mlirDenseElementsAttrIsSplat(*this)) { 506 // TODO: Currently crashes the program. 507 // Reported as https://github.com/pybind/pybind11/issues/3336 508 throw std::invalid_argument( 509 "unsupported data type for conversion to Python buffer"); 510 } 511 512 MlirType shapedType = mlirAttributeGetType(*this); 513 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 514 std::string format; 515 516 if (mlirTypeIsAF32(elementType)) { 517 // f32 518 return bufferInfo<float>(shapedType); 519 } 520 if (mlirTypeIsAF64(elementType)) { 521 // f64 522 return bufferInfo<double>(shapedType); 523 } 524 if (mlirTypeIsAF16(elementType)) { 525 // f16 526 return bufferInfo<uint16_t>(shapedType, "e"); 527 } 528 if (mlirTypeIsAInteger(elementType) && 529 mlirIntegerTypeGetWidth(elementType) == 32) { 530 if (mlirIntegerTypeIsSignless(elementType) || 531 mlirIntegerTypeIsSigned(elementType)) { 532 // i32 533 return bufferInfo<int32_t>(shapedType); 534 } 535 if (mlirIntegerTypeIsUnsigned(elementType)) { 536 // unsigned i32 537 return bufferInfo<uint32_t>(shapedType); 538 } 539 } else if (mlirTypeIsAInteger(elementType) && 540 mlirIntegerTypeGetWidth(elementType) == 64) { 541 if (mlirIntegerTypeIsSignless(elementType) || 542 mlirIntegerTypeIsSigned(elementType)) { 543 // i64 544 return bufferInfo<int64_t>(shapedType); 545 } 546 if (mlirIntegerTypeIsUnsigned(elementType)) { 547 // unsigned i64 548 return bufferInfo<uint64_t>(shapedType); 549 } 550 } else if (mlirTypeIsAInteger(elementType) && 551 mlirIntegerTypeGetWidth(elementType) == 8) { 552 if (mlirIntegerTypeIsSignless(elementType) || 553 mlirIntegerTypeIsSigned(elementType)) { 554 // i8 555 return bufferInfo<int8_t>(shapedType); 556 } 557 if (mlirIntegerTypeIsUnsigned(elementType)) { 558 // unsigned i8 559 return bufferInfo<uint8_t>(shapedType); 560 } 561 } else if (mlirTypeIsAInteger(elementType) && 562 mlirIntegerTypeGetWidth(elementType) == 16) { 563 if (mlirIntegerTypeIsSignless(elementType) || 564 mlirIntegerTypeIsSigned(elementType)) { 565 // i16 566 return bufferInfo<int16_t>(shapedType); 567 } 568 if (mlirIntegerTypeIsUnsigned(elementType)) { 569 // unsigned i16 570 return bufferInfo<uint16_t>(shapedType); 571 } 572 } 573 574 // TODO: Currently crashes the program. 575 // Reported as https://github.com/pybind/pybind11/issues/3336 576 throw std::invalid_argument( 577 "unsupported data type for conversion to Python buffer"); 578 } 579 580 static void bindDerived(ClassTy &c) { 581 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 582 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 583 py::arg("array"), py::arg("signless") = true, 584 py::arg("type") = py::none(), py::arg("shape") = py::none(), 585 py::arg("context") = py::none(), 586 kDenseElementsAttrGetDocstring) 587 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 588 py::arg("shaped_type"), py::arg("element_attr"), 589 "Gets a DenseElementsAttr where all values are the same") 590 .def_property_readonly("is_splat", 591 [](PyDenseElementsAttribute &self) -> bool { 592 return mlirDenseElementsAttrIsSplat(self); 593 }) 594 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 595 } 596 597 private: 598 static bool isUnsignedIntegerFormat(const std::string &format) { 599 if (format.empty()) 600 return false; 601 char code = format[0]; 602 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 603 code == 'Q'; 604 } 605 606 static bool isSignedIntegerFormat(const std::string &format) { 607 if (format.empty()) 608 return false; 609 char code = format[0]; 610 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 611 code == 'q'; 612 } 613 614 template <typename Type> 615 py::buffer_info bufferInfo(MlirType shapedType, 616 const char *explicitFormat = nullptr) { 617 intptr_t rank = mlirShapedTypeGetRank(shapedType); 618 // Prepare the data for the buffer_info. 619 // Buffer is configured for read-only access below. 620 Type *data = static_cast<Type *>( 621 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 622 // Prepare the shape for the buffer_info. 623 SmallVector<intptr_t, 4> shape; 624 for (intptr_t i = 0; i < rank; ++i) 625 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 626 // Prepare the strides for the buffer_info. 627 SmallVector<intptr_t, 4> strides; 628 intptr_t strideFactor = 1; 629 for (intptr_t i = 1; i < rank; ++i) { 630 strideFactor = 1; 631 for (intptr_t j = i; j < rank; ++j) { 632 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 633 } 634 strides.push_back(sizeof(Type) * strideFactor); 635 } 636 strides.push_back(sizeof(Type)); 637 std::string format; 638 if (explicitFormat) { 639 format = explicitFormat; 640 } else { 641 format = py::format_descriptor<Type>::format(); 642 } 643 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 644 /*readonly=*/true); 645 } 646 }; // namespace 647 648 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 649 /// (and boolean) values. Supports element access. 650 class PyDenseIntElementsAttribute 651 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 652 PyDenseElementsAttribute> { 653 public: 654 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 655 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 656 using PyConcreteAttribute::PyConcreteAttribute; 657 658 /// Returns the element at the given linear position. Asserts if the index is 659 /// out of range. 660 py::int_ dunderGetItem(intptr_t pos) { 661 if (pos < 0 || pos >= dunderLen()) { 662 throw SetPyError(PyExc_IndexError, 663 "attempt to access out of bounds element"); 664 } 665 666 MlirType type = mlirAttributeGetType(*this); 667 type = mlirShapedTypeGetElementType(type); 668 assert(mlirTypeIsAInteger(type) && 669 "expected integer element type in dense int elements attribute"); 670 // Dispatch element extraction to an appropriate C function based on the 671 // elemental type of the attribute. py::int_ is implicitly constructible 672 // from any C++ integral type and handles bitwidth correctly. 673 // TODO: consider caching the type properties in the constructor to avoid 674 // querying them on each element access. 675 unsigned width = mlirIntegerTypeGetWidth(type); 676 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 677 if (isUnsigned) { 678 if (width == 1) { 679 return mlirDenseElementsAttrGetBoolValue(*this, pos); 680 } 681 if (width == 8) { 682 return mlirDenseElementsAttrGetUInt8Value(*this, pos); 683 } 684 if (width == 16) { 685 return mlirDenseElementsAttrGetUInt16Value(*this, pos); 686 } 687 if (width == 32) { 688 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 689 } 690 if (width == 64) { 691 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 692 } 693 } else { 694 if (width == 1) { 695 return mlirDenseElementsAttrGetBoolValue(*this, pos); 696 } 697 if (width == 8) { 698 return mlirDenseElementsAttrGetInt8Value(*this, pos); 699 } 700 if (width == 16) { 701 return mlirDenseElementsAttrGetInt16Value(*this, pos); 702 } 703 if (width == 32) { 704 return mlirDenseElementsAttrGetInt32Value(*this, pos); 705 } 706 if (width == 64) { 707 return mlirDenseElementsAttrGetInt64Value(*this, pos); 708 } 709 } 710 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 711 } 712 713 static void bindDerived(ClassTy &c) { 714 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 715 } 716 }; 717 718 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 719 public: 720 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 721 static constexpr const char *pyClassName = "DictAttr"; 722 using PyConcreteAttribute::PyConcreteAttribute; 723 724 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 725 726 bool dunderContains(const std::string &name) { 727 return !mlirAttributeIsNull( 728 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 729 } 730 731 static void bindDerived(ClassTy &c) { 732 c.def("__contains__", &PyDictAttribute::dunderContains); 733 c.def("__len__", &PyDictAttribute::dunderLen); 734 c.def_static( 735 "get", 736 [](py::dict attributes, DefaultingPyMlirContext context) { 737 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 738 mlirNamedAttributes.reserve(attributes.size()); 739 for (auto &it : attributes) { 740 auto &mlirAttr = it.second.cast<PyAttribute &>(); 741 auto name = it.first.cast<std::string>(); 742 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 743 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 744 toMlirStringRef(name)), 745 mlirAttr)); 746 } 747 MlirAttribute attr = 748 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 749 mlirNamedAttributes.data()); 750 return PyDictAttribute(context->getRef(), attr); 751 }, 752 py::arg("value") = py::dict(), py::arg("context") = py::none(), 753 "Gets an uniqued dict attribute"); 754 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 755 MlirAttribute attr = 756 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 757 if (mlirAttributeIsNull(attr)) { 758 throw SetPyError(PyExc_KeyError, 759 "attempt to access a non-existent attribute"); 760 } 761 return PyAttribute(self.getContext(), attr); 762 }); 763 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 764 if (index < 0 || index >= self.dunderLen()) { 765 throw SetPyError(PyExc_IndexError, 766 "attempt to access out of bounds attribute"); 767 } 768 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 769 return PyNamedAttribute( 770 namedAttr.attribute, 771 std::string(mlirIdentifierStr(namedAttr.name).data)); 772 }); 773 } 774 }; 775 776 /// Refinement of PyDenseElementsAttribute for attributes containing 777 /// floating-point values. Supports element access. 778 class PyDenseFPElementsAttribute 779 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 780 PyDenseElementsAttribute> { 781 public: 782 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 783 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 784 using PyConcreteAttribute::PyConcreteAttribute; 785 786 py::float_ dunderGetItem(intptr_t pos) { 787 if (pos < 0 || pos >= dunderLen()) { 788 throw SetPyError(PyExc_IndexError, 789 "attempt to access out of bounds element"); 790 } 791 792 MlirType type = mlirAttributeGetType(*this); 793 type = mlirShapedTypeGetElementType(type); 794 // Dispatch element extraction to an appropriate C function based on the 795 // elemental type of the attribute. py::float_ is implicitly constructible 796 // from float and double. 797 // TODO: consider caching the type properties in the constructor to avoid 798 // querying them on each element access. 799 if (mlirTypeIsAF32(type)) { 800 return mlirDenseElementsAttrGetFloatValue(*this, pos); 801 } 802 if (mlirTypeIsAF64(type)) { 803 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 804 } 805 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 806 } 807 808 static void bindDerived(ClassTy &c) { 809 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 810 } 811 }; 812 813 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 814 public: 815 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 816 static constexpr const char *pyClassName = "TypeAttr"; 817 using PyConcreteAttribute::PyConcreteAttribute; 818 819 static void bindDerived(ClassTy &c) { 820 c.def_static( 821 "get", 822 [](PyType value, DefaultingPyMlirContext context) { 823 MlirAttribute attr = mlirTypeAttrGet(value.get()); 824 return PyTypeAttribute(context->getRef(), attr); 825 }, 826 py::arg("value"), py::arg("context") = py::none(), 827 "Gets a uniqued Type attribute"); 828 c.def_property_readonly("value", [](PyTypeAttribute &self) { 829 return PyType(self.getContext()->getRef(), 830 mlirTypeAttrGetValue(self.get())); 831 }); 832 } 833 }; 834 835 /// Unit Attribute subclass. Unit attributes don't have values. 836 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 837 public: 838 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 839 static constexpr const char *pyClassName = "UnitAttr"; 840 using PyConcreteAttribute::PyConcreteAttribute; 841 842 static void bindDerived(ClassTy &c) { 843 c.def_static( 844 "get", 845 [](DefaultingPyMlirContext context) { 846 return PyUnitAttribute(context->getRef(), 847 mlirUnitAttrGet(context->get())); 848 }, 849 py::arg("context") = py::none(), "Create a Unit attribute."); 850 } 851 }; 852 853 } // namespace 854 855 void mlir::python::populateIRAttributes(py::module &m) { 856 PyAffineMapAttribute::bind(m); 857 PyArrayAttribute::bind(m); 858 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 859 PyBoolAttribute::bind(m); 860 PyDenseElementsAttribute::bind(m); 861 PyDenseFPElementsAttribute::bind(m); 862 PyDenseIntElementsAttribute::bind(m); 863 PyDictAttribute::bind(m); 864 PyFlatSymbolRefAttribute::bind(m); 865 PyFloatAttribute::bind(m); 866 PyIntegerAttribute::bind(m); 867 PyStringAttribute::bind(m); 868 PyTypeAttribute::bind(m); 869 PyUnitAttribute::bind(m); 870 } 871