1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "IRModule.h" 12 13 #include "PybindUtils.h" 14 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 18 namespace py = pybind11; 19 using namespace mlir; 20 using namespace mlir::python; 21 22 using llvm::Optional; 23 using llvm::SmallVector; 24 using llvm::Twine; 25 26 //------------------------------------------------------------------------------ 27 // Docstrings (trivial, non-duplicated docstrings are included inline). 28 //------------------------------------------------------------------------------ 29 30 static const char kDenseElementsAttrGetDocstring[] = 31 R"(Gets a DenseElementsAttr from a Python buffer or array. 32 33 When `type` is not provided, then some limited type inferencing is done based 34 on the buffer format. Support presently exists for 8/16/32/64 signed and 35 unsigned integers and float16/float32/float64. DenseElementsAttrs of these 36 types can also be converted back to a corresponding buffer. 37 38 For conversions outside of these types, a `type=` must be explicitly provided 39 and the buffer contents must be bit-castable to the MLIR internal 40 representation: 41 42 * Integer types (except for i1): the buffer must be byte aligned to the 43 next byte boundary. 44 * Floating point types: Must be bit-castable to the given floating point 45 size. 46 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 47 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 48 this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 49 50 If a single element buffer is passed (or for i1, a single byte with value 0 51 or 255), then a splat will be created. 52 53 Args: 54 array: The array or buffer to convert. 55 signless: If inferring an appropriate MLIR type, use signless types for 56 integers (defaults True). 57 type: Skips inference of the MLIR element type and uses this instead. The 58 storage size must be consistent with the actual contents of the buffer. 59 shape: Overrides the shape of the buffer when constructing the MLIR 60 shaped type. This is needed when the physical and logical shape differ (as 61 for i1). 62 context: Explicit context, if not from context manager. 63 64 Returns: 65 DenseElementsAttr on success. 66 67 Raises: 68 ValueError: If the type of the buffer or array cannot be matched to an MLIR 69 type or if the buffer does not meet expectations. 70 )"; 71 72 namespace { 73 74 static MlirStringRef toMlirStringRef(const std::string &s) { 75 return mlirStringRefCreate(s.data(), s.size()); 76 } 77 78 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79 public: 80 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81 static constexpr const char *pyClassName = "AffineMapAttr"; 82 using PyConcreteAttribute::PyConcreteAttribute; 83 84 static void bindDerived(ClassTy &c) { 85 c.def_static( 86 "get", 87 [](PyAffineMap &affineMap) { 88 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 89 return PyAffineMapAttribute(affineMap.getContext(), attr); 90 }, 91 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 92 } 93 }; 94 95 template <typename T> 96 static T pyTryCast(py::handle object) { 97 try { 98 return object.cast<T>(); 99 } catch (py::cast_error &err) { 100 std::string msg = 101 std::string( 102 "Invalid attribute when attempting to create an ArrayAttribute (") + 103 err.what() + ")"; 104 throw py::cast_error(msg); 105 } catch (py::reference_cast_error &err) { 106 std::string msg = std::string("Invalid attribute (None?) when attempting " 107 "to create an ArrayAttribute (") + 108 err.what() + ")"; 109 throw py::cast_error(msg); 110 } 111 } 112 113 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 114 public: 115 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 116 static constexpr const char *pyClassName = "ArrayAttr"; 117 using PyConcreteAttribute::PyConcreteAttribute; 118 119 class PyArrayAttributeIterator { 120 public: 121 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 122 123 PyArrayAttributeIterator &dunderIter() { return *this; } 124 125 PyAttribute dunderNext() { 126 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 127 throw py::stop_iteration(); 128 } 129 return PyAttribute(attr.getContext(), 130 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 131 } 132 133 static void bind(py::module &m) { 134 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 135 py::module_local()) 136 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 137 .def("__next__", &PyArrayAttributeIterator::dunderNext); 138 } 139 140 private: 141 PyAttribute attr; 142 int nextIndex = 0; 143 }; 144 145 PyAttribute getItem(intptr_t i) { 146 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 147 } 148 149 static void bindDerived(ClassTy &c) { 150 c.def_static( 151 "get", 152 [](py::list attributes, DefaultingPyMlirContext context) { 153 SmallVector<MlirAttribute> mlirAttributes; 154 mlirAttributes.reserve(py::len(attributes)); 155 for (auto attribute : attributes) { 156 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 157 } 158 MlirAttribute attr = mlirArrayAttrGet( 159 context->get(), mlirAttributes.size(), mlirAttributes.data()); 160 return PyArrayAttribute(context->getRef(), attr); 161 }, 162 py::arg("attributes"), py::arg("context") = py::none(), 163 "Gets a uniqued Array attribute"); 164 c.def("__getitem__", 165 [](PyArrayAttribute &arr, intptr_t i) { 166 if (i >= mlirArrayAttrGetNumElements(arr)) 167 throw py::index_error("ArrayAttribute index out of range"); 168 return arr.getItem(i); 169 }) 170 .def("__len__", 171 [](const PyArrayAttribute &arr) { 172 return mlirArrayAttrGetNumElements(arr); 173 }) 174 .def("__iter__", [](const PyArrayAttribute &arr) { 175 return PyArrayAttributeIterator(arr); 176 }); 177 c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 178 std::vector<MlirAttribute> attributes; 179 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 180 attributes.reserve(numOldElements + py::len(extras)); 181 for (intptr_t i = 0; i < numOldElements; ++i) 182 attributes.push_back(arr.getItem(i)); 183 for (py::handle attr : extras) 184 attributes.push_back(pyTryCast<PyAttribute>(attr)); 185 MlirAttribute arrayAttr = mlirArrayAttrGet( 186 arr.getContext()->get(), attributes.size(), attributes.data()); 187 return PyArrayAttribute(arr.getContext(), arrayAttr); 188 }); 189 } 190 }; 191 192 /// Float Point Attribute subclass - FloatAttr. 193 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 194 public: 195 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 196 static constexpr const char *pyClassName = "FloatAttr"; 197 using PyConcreteAttribute::PyConcreteAttribute; 198 199 static void bindDerived(ClassTy &c) { 200 c.def_static( 201 "get", 202 [](PyType &type, double value, DefaultingPyLocation loc) { 203 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 204 // TODO: Rework error reporting once diagnostic engine is exposed 205 // in C API. 206 if (mlirAttributeIsNull(attr)) { 207 throw SetPyError(PyExc_ValueError, 208 Twine("invalid '") + 209 py::repr(py::cast(type)).cast<std::string>() + 210 "' and expected floating point type."); 211 } 212 return PyFloatAttribute(type.getContext(), attr); 213 }, 214 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 215 "Gets an uniqued float point attribute associated to a type"); 216 c.def_static( 217 "get_f32", 218 [](double value, DefaultingPyMlirContext context) { 219 MlirAttribute attr = mlirFloatAttrDoubleGet( 220 context->get(), mlirF32TypeGet(context->get()), value); 221 return PyFloatAttribute(context->getRef(), attr); 222 }, 223 py::arg("value"), py::arg("context") = py::none(), 224 "Gets an uniqued float point attribute associated to a f32 type"); 225 c.def_static( 226 "get_f64", 227 [](double value, DefaultingPyMlirContext context) { 228 MlirAttribute attr = mlirFloatAttrDoubleGet( 229 context->get(), mlirF64TypeGet(context->get()), value); 230 return PyFloatAttribute(context->getRef(), attr); 231 }, 232 py::arg("value"), py::arg("context") = py::none(), 233 "Gets an uniqued float point attribute associated to a f64 type"); 234 c.def_property_readonly( 235 "value", 236 [](PyFloatAttribute &self) { 237 return mlirFloatAttrGetValueDouble(self); 238 }, 239 "Returns the value of the float point attribute"); 240 } 241 }; 242 243 /// Integer Attribute subclass - IntegerAttr. 244 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 245 public: 246 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 247 static constexpr const char *pyClassName = "IntegerAttr"; 248 using PyConcreteAttribute::PyConcreteAttribute; 249 250 static void bindDerived(ClassTy &c) { 251 c.def_static( 252 "get", 253 [](PyType &type, int64_t value) { 254 MlirAttribute attr = mlirIntegerAttrGet(type, value); 255 return PyIntegerAttribute(type.getContext(), attr); 256 }, 257 py::arg("type"), py::arg("value"), 258 "Gets an uniqued integer attribute associated to a type"); 259 c.def_property_readonly( 260 "value", 261 [](PyIntegerAttribute &self) -> py::int_ { 262 MlirType type = mlirAttributeGetType(self); 263 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 264 return mlirIntegerAttrGetValueInt(self); 265 if (mlirIntegerTypeIsSigned(type)) 266 return mlirIntegerAttrGetValueSInt(self); 267 return mlirIntegerAttrGetValueUInt(self); 268 }, 269 "Returns the value of the integer attribute"); 270 } 271 }; 272 273 /// Bool Attribute subclass - BoolAttr. 274 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 275 public: 276 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 277 static constexpr const char *pyClassName = "BoolAttr"; 278 using PyConcreteAttribute::PyConcreteAttribute; 279 280 static void bindDerived(ClassTy &c) { 281 c.def_static( 282 "get", 283 [](bool value, DefaultingPyMlirContext context) { 284 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 285 return PyBoolAttribute(context->getRef(), attr); 286 }, 287 py::arg("value"), py::arg("context") = py::none(), 288 "Gets an uniqued bool attribute"); 289 c.def_property_readonly( 290 "value", 291 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 292 "Returns the value of the bool attribute"); 293 } 294 }; 295 296 class PyFlatSymbolRefAttribute 297 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 298 public: 299 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 300 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 301 using PyConcreteAttribute::PyConcreteAttribute; 302 303 static void bindDerived(ClassTy &c) { 304 c.def_static( 305 "get", 306 [](std::string value, DefaultingPyMlirContext context) { 307 MlirAttribute attr = 308 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 309 return PyFlatSymbolRefAttribute(context->getRef(), attr); 310 }, 311 py::arg("value"), py::arg("context") = py::none(), 312 "Gets a uniqued FlatSymbolRef attribute"); 313 c.def_property_readonly( 314 "value", 315 [](PyFlatSymbolRefAttribute &self) { 316 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 317 return py::str(stringRef.data, stringRef.length); 318 }, 319 "Returns the value of the FlatSymbolRef attribute as a string"); 320 } 321 }; 322 323 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 324 public: 325 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 326 static constexpr const char *pyClassName = "OpaqueAttr"; 327 using PyConcreteAttribute::PyConcreteAttribute; 328 329 static void bindDerived(ClassTy &c) { 330 c.def_static( 331 "get", 332 [](std::string dialectNamespace, py::buffer buffer, PyType &type, 333 DefaultingPyMlirContext context) { 334 const py::buffer_info bufferInfo = buffer.request(); 335 intptr_t bufferSize = bufferInfo.size; 336 MlirAttribute attr = mlirOpaqueAttrGet( 337 context->get(), toMlirStringRef(dialectNamespace), bufferSize, 338 static_cast<char *>(bufferInfo.ptr), type); 339 return PyOpaqueAttribute(context->getRef(), attr); 340 }, 341 py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 342 py::arg("context") = py::none(), "Gets an Opaque attribute."); 343 c.def_property_readonly( 344 "dialect_namespace", 345 [](PyOpaqueAttribute &self) { 346 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 347 return py::str(stringRef.data, stringRef.length); 348 }, 349 "Returns the dialect namespace for the Opaque attribute as a string"); 350 c.def_property_readonly( 351 "data", 352 [](PyOpaqueAttribute &self) { 353 MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 354 return py::str(stringRef.data, stringRef.length); 355 }, 356 "Returns the data for the Opaqued attributes as a string"); 357 } 358 }; 359 360 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 361 public: 362 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 363 static constexpr const char *pyClassName = "StringAttr"; 364 using PyConcreteAttribute::PyConcreteAttribute; 365 366 static void bindDerived(ClassTy &c) { 367 c.def_static( 368 "get", 369 [](std::string value, DefaultingPyMlirContext context) { 370 MlirAttribute attr = 371 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 372 return PyStringAttribute(context->getRef(), attr); 373 }, 374 py::arg("value"), py::arg("context") = py::none(), 375 "Gets a uniqued string attribute"); 376 c.def_static( 377 "get_typed", 378 [](PyType &type, std::string value) { 379 MlirAttribute attr = 380 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 381 return PyStringAttribute(type.getContext(), attr); 382 }, 383 py::arg("type"), py::arg("value"), 384 "Gets a uniqued string attribute associated to a type"); 385 c.def_property_readonly( 386 "value", 387 [](PyStringAttribute &self) { 388 MlirStringRef stringRef = mlirStringAttrGetValue(self); 389 return py::str(stringRef.data, stringRef.length); 390 }, 391 "Returns the value of the string attribute"); 392 } 393 }; 394 395 // TODO: Support construction of string elements. 396 class PyDenseElementsAttribute 397 : public PyConcreteAttribute<PyDenseElementsAttribute> { 398 public: 399 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 400 static constexpr const char *pyClassName = "DenseElementsAttr"; 401 using PyConcreteAttribute::PyConcreteAttribute; 402 403 static PyDenseElementsAttribute 404 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 405 Optional<std::vector<int64_t>> explicitShape, 406 DefaultingPyMlirContext contextWrapper) { 407 // Request a contiguous view. In exotic cases, this will cause a copy. 408 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 409 Py_buffer *view = new Py_buffer(); 410 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 411 delete view; 412 throw py::error_already_set(); 413 } 414 py::buffer_info arrayInfo(view); 415 SmallVector<int64_t> shape; 416 if (explicitShape) { 417 shape.append(explicitShape->begin(), explicitShape->end()); 418 } else { 419 shape.append(arrayInfo.shape.begin(), 420 arrayInfo.shape.begin() + arrayInfo.ndim); 421 } 422 423 MlirAttribute encodingAttr = mlirAttributeGetNull(); 424 MlirContext context = contextWrapper->get(); 425 426 // Detect format codes that are suitable for bulk loading. This includes 427 // all byte aligned integer and floating point types up to 8 bytes. 428 // Notably, this excludes, bool (which needs to be bit-packed) and 429 // other exotics which do not have a direct representation in the buffer 430 // protocol (i.e. complex, etc). 431 Optional<MlirType> bulkLoadElementType; 432 if (explicitType) { 433 bulkLoadElementType = *explicitType; 434 } else if (arrayInfo.format == "f") { 435 // f32 436 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 437 bulkLoadElementType = mlirF32TypeGet(context); 438 } else if (arrayInfo.format == "d") { 439 // f64 440 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 441 bulkLoadElementType = mlirF64TypeGet(context); 442 } else if (arrayInfo.format == "e") { 443 // f16 444 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 445 bulkLoadElementType = mlirF16TypeGet(context); 446 } else if (isSignedIntegerFormat(arrayInfo.format)) { 447 if (arrayInfo.itemsize == 4) { 448 // i32 449 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 450 : mlirIntegerTypeSignedGet(context, 32); 451 } else if (arrayInfo.itemsize == 8) { 452 // i64 453 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 454 : mlirIntegerTypeSignedGet(context, 64); 455 } else if (arrayInfo.itemsize == 1) { 456 // i8 457 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 458 : mlirIntegerTypeSignedGet(context, 8); 459 } else if (arrayInfo.itemsize == 2) { 460 // i16 461 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 462 : mlirIntegerTypeSignedGet(context, 16); 463 } 464 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 465 if (arrayInfo.itemsize == 4) { 466 // unsigned i32 467 bulkLoadElementType = signless 468 ? mlirIntegerTypeGet(context, 32) 469 : mlirIntegerTypeUnsignedGet(context, 32); 470 } else if (arrayInfo.itemsize == 8) { 471 // unsigned i64 472 bulkLoadElementType = signless 473 ? mlirIntegerTypeGet(context, 64) 474 : mlirIntegerTypeUnsignedGet(context, 64); 475 } else if (arrayInfo.itemsize == 1) { 476 // i8 477 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 478 : mlirIntegerTypeUnsignedGet(context, 8); 479 } else if (arrayInfo.itemsize == 2) { 480 // i16 481 bulkLoadElementType = signless 482 ? mlirIntegerTypeGet(context, 16) 483 : mlirIntegerTypeUnsignedGet(context, 16); 484 } 485 } 486 if (bulkLoadElementType) { 487 auto shapedType = mlirRankedTensorTypeGet( 488 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 489 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 490 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 491 shapedType, rawBufferSize, arrayInfo.ptr); 492 if (mlirAttributeIsNull(attr)) { 493 throw std::invalid_argument( 494 "DenseElementsAttr could not be constructed from the given buffer. " 495 "This may mean that the Python buffer layout does not match that " 496 "MLIR expected layout and is a bug."); 497 } 498 return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 499 } 500 501 throw std::invalid_argument( 502 std::string("unimplemented array format conversion from format: ") + 503 arrayInfo.format); 504 } 505 506 static PyDenseElementsAttribute getSplat(const PyType &shapedType, 507 PyAttribute &elementAttr) { 508 auto contextWrapper = 509 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 510 if (!mlirAttributeIsAInteger(elementAttr) && 511 !mlirAttributeIsAFloat(elementAttr)) { 512 std::string message = "Illegal element type for DenseElementsAttr: "; 513 message.append(py::repr(py::cast(elementAttr))); 514 throw SetPyError(PyExc_ValueError, message); 515 } 516 if (!mlirTypeIsAShaped(shapedType) || 517 !mlirShapedTypeHasStaticShape(shapedType)) { 518 std::string message = 519 "Expected a static ShapedType for the shaped_type parameter: "; 520 message.append(py::repr(py::cast(shapedType))); 521 throw SetPyError(PyExc_ValueError, message); 522 } 523 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 524 MlirType attrType = mlirAttributeGetType(elementAttr); 525 if (!mlirTypeEqual(shapedElementType, attrType)) { 526 std::string message = 527 "Shaped element type and attribute type must be equal: shaped="; 528 message.append(py::repr(py::cast(shapedType))); 529 message.append(", element="); 530 message.append(py::repr(py::cast(elementAttr))); 531 throw SetPyError(PyExc_ValueError, message); 532 } 533 534 MlirAttribute elements = 535 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 536 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 537 } 538 539 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 540 541 py::buffer_info accessBuffer() { 542 if (mlirDenseElementsAttrIsSplat(*this)) { 543 // TODO: Currently crashes the program. 544 // Reported as https://github.com/pybind/pybind11/issues/3336 545 throw std::invalid_argument( 546 "unsupported data type for conversion to Python buffer"); 547 } 548 549 MlirType shapedType = mlirAttributeGetType(*this); 550 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 551 std::string format; 552 553 if (mlirTypeIsAF32(elementType)) { 554 // f32 555 return bufferInfo<float>(shapedType); 556 } 557 if (mlirTypeIsAF64(elementType)) { 558 // f64 559 return bufferInfo<double>(shapedType); 560 } 561 if (mlirTypeIsAF16(elementType)) { 562 // f16 563 return bufferInfo<uint16_t>(shapedType, "e"); 564 } 565 if (mlirTypeIsAInteger(elementType) && 566 mlirIntegerTypeGetWidth(elementType) == 32) { 567 if (mlirIntegerTypeIsSignless(elementType) || 568 mlirIntegerTypeIsSigned(elementType)) { 569 // i32 570 return bufferInfo<int32_t>(shapedType); 571 } 572 if (mlirIntegerTypeIsUnsigned(elementType)) { 573 // unsigned i32 574 return bufferInfo<uint32_t>(shapedType); 575 } 576 } else if (mlirTypeIsAInteger(elementType) && 577 mlirIntegerTypeGetWidth(elementType) == 64) { 578 if (mlirIntegerTypeIsSignless(elementType) || 579 mlirIntegerTypeIsSigned(elementType)) { 580 // i64 581 return bufferInfo<int64_t>(shapedType); 582 } 583 if (mlirIntegerTypeIsUnsigned(elementType)) { 584 // unsigned i64 585 return bufferInfo<uint64_t>(shapedType); 586 } 587 } else if (mlirTypeIsAInteger(elementType) && 588 mlirIntegerTypeGetWidth(elementType) == 8) { 589 if (mlirIntegerTypeIsSignless(elementType) || 590 mlirIntegerTypeIsSigned(elementType)) { 591 // i8 592 return bufferInfo<int8_t>(shapedType); 593 } 594 if (mlirIntegerTypeIsUnsigned(elementType)) { 595 // unsigned i8 596 return bufferInfo<uint8_t>(shapedType); 597 } 598 } else if (mlirTypeIsAInteger(elementType) && 599 mlirIntegerTypeGetWidth(elementType) == 16) { 600 if (mlirIntegerTypeIsSignless(elementType) || 601 mlirIntegerTypeIsSigned(elementType)) { 602 // i16 603 return bufferInfo<int16_t>(shapedType); 604 } 605 if (mlirIntegerTypeIsUnsigned(elementType)) { 606 // unsigned i16 607 return bufferInfo<uint16_t>(shapedType); 608 } 609 } 610 611 // TODO: Currently crashes the program. 612 // Reported as https://github.com/pybind/pybind11/issues/3336 613 throw std::invalid_argument( 614 "unsupported data type for conversion to Python buffer"); 615 } 616 617 static void bindDerived(ClassTy &c) { 618 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 619 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 620 py::arg("array"), py::arg("signless") = true, 621 py::arg("type") = py::none(), py::arg("shape") = py::none(), 622 py::arg("context") = py::none(), 623 kDenseElementsAttrGetDocstring) 624 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 625 py::arg("shaped_type"), py::arg("element_attr"), 626 "Gets a DenseElementsAttr where all values are the same") 627 .def_property_readonly("is_splat", 628 [](PyDenseElementsAttribute &self) -> bool { 629 return mlirDenseElementsAttrIsSplat(self); 630 }) 631 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 632 } 633 634 private: 635 static bool isUnsignedIntegerFormat(const std::string &format) { 636 if (format.empty()) 637 return false; 638 char code = format[0]; 639 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 640 code == 'Q'; 641 } 642 643 static bool isSignedIntegerFormat(const std::string &format) { 644 if (format.empty()) 645 return false; 646 char code = format[0]; 647 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 648 code == 'q'; 649 } 650 651 template <typename Type> 652 py::buffer_info bufferInfo(MlirType shapedType, 653 const char *explicitFormat = nullptr) { 654 intptr_t rank = mlirShapedTypeGetRank(shapedType); 655 // Prepare the data for the buffer_info. 656 // Buffer is configured for read-only access below. 657 Type *data = static_cast<Type *>( 658 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 659 // Prepare the shape for the buffer_info. 660 SmallVector<intptr_t, 4> shape; 661 for (intptr_t i = 0; i < rank; ++i) 662 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 663 // Prepare the strides for the buffer_info. 664 SmallVector<intptr_t, 4> strides; 665 intptr_t strideFactor = 1; 666 for (intptr_t i = 1; i < rank; ++i) { 667 strideFactor = 1; 668 for (intptr_t j = i; j < rank; ++j) { 669 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 670 } 671 strides.push_back(sizeof(Type) * strideFactor); 672 } 673 strides.push_back(sizeof(Type)); 674 std::string format; 675 if (explicitFormat) { 676 format = explicitFormat; 677 } else { 678 format = py::format_descriptor<Type>::format(); 679 } 680 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 681 /*readonly=*/true); 682 } 683 }; // namespace 684 685 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 686 /// (and boolean) values. Supports element access. 687 class PyDenseIntElementsAttribute 688 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 689 PyDenseElementsAttribute> { 690 public: 691 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 692 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 693 using PyConcreteAttribute::PyConcreteAttribute; 694 695 /// Returns the element at the given linear position. Asserts if the index is 696 /// out of range. 697 py::int_ dunderGetItem(intptr_t pos) { 698 if (pos < 0 || pos >= dunderLen()) { 699 throw SetPyError(PyExc_IndexError, 700 "attempt to access out of bounds element"); 701 } 702 703 MlirType type = mlirAttributeGetType(*this); 704 type = mlirShapedTypeGetElementType(type); 705 assert(mlirTypeIsAInteger(type) && 706 "expected integer element type in dense int elements attribute"); 707 // Dispatch element extraction to an appropriate C function based on the 708 // elemental type of the attribute. py::int_ is implicitly constructible 709 // from any C++ integral type and handles bitwidth correctly. 710 // TODO: consider caching the type properties in the constructor to avoid 711 // querying them on each element access. 712 unsigned width = mlirIntegerTypeGetWidth(type); 713 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 714 if (isUnsigned) { 715 if (width == 1) { 716 return mlirDenseElementsAttrGetBoolValue(*this, pos); 717 } 718 if (width == 8) { 719 return mlirDenseElementsAttrGetUInt8Value(*this, pos); 720 } 721 if (width == 16) { 722 return mlirDenseElementsAttrGetUInt16Value(*this, pos); 723 } 724 if (width == 32) { 725 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 726 } 727 if (width == 64) { 728 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 729 } 730 } else { 731 if (width == 1) { 732 return mlirDenseElementsAttrGetBoolValue(*this, pos); 733 } 734 if (width == 8) { 735 return mlirDenseElementsAttrGetInt8Value(*this, pos); 736 } 737 if (width == 16) { 738 return mlirDenseElementsAttrGetInt16Value(*this, pos); 739 } 740 if (width == 32) { 741 return mlirDenseElementsAttrGetInt32Value(*this, pos); 742 } 743 if (width == 64) { 744 return mlirDenseElementsAttrGetInt64Value(*this, pos); 745 } 746 } 747 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 748 } 749 750 static void bindDerived(ClassTy &c) { 751 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 752 } 753 }; 754 755 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 756 public: 757 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 758 static constexpr const char *pyClassName = "DictAttr"; 759 using PyConcreteAttribute::PyConcreteAttribute; 760 761 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 762 763 bool dunderContains(const std::string &name) { 764 return !mlirAttributeIsNull( 765 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 766 } 767 768 static void bindDerived(ClassTy &c) { 769 c.def("__contains__", &PyDictAttribute::dunderContains); 770 c.def("__len__", &PyDictAttribute::dunderLen); 771 c.def_static( 772 "get", 773 [](py::dict attributes, DefaultingPyMlirContext context) { 774 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 775 mlirNamedAttributes.reserve(attributes.size()); 776 for (auto &it : attributes) { 777 auto &mlirAttr = it.second.cast<PyAttribute &>(); 778 auto name = it.first.cast<std::string>(); 779 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 780 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 781 toMlirStringRef(name)), 782 mlirAttr)); 783 } 784 MlirAttribute attr = 785 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 786 mlirNamedAttributes.data()); 787 return PyDictAttribute(context->getRef(), attr); 788 }, 789 py::arg("value") = py::dict(), py::arg("context") = py::none(), 790 "Gets an uniqued dict attribute"); 791 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 792 MlirAttribute attr = 793 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 794 if (mlirAttributeIsNull(attr)) { 795 throw SetPyError(PyExc_KeyError, 796 "attempt to access a non-existent attribute"); 797 } 798 return PyAttribute(self.getContext(), attr); 799 }); 800 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 801 if (index < 0 || index >= self.dunderLen()) { 802 throw SetPyError(PyExc_IndexError, 803 "attempt to access out of bounds attribute"); 804 } 805 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 806 return PyNamedAttribute( 807 namedAttr.attribute, 808 std::string(mlirIdentifierStr(namedAttr.name).data)); 809 }); 810 } 811 }; 812 813 /// Refinement of PyDenseElementsAttribute for attributes containing 814 /// floating-point values. Supports element access. 815 class PyDenseFPElementsAttribute 816 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 817 PyDenseElementsAttribute> { 818 public: 819 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 820 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 821 using PyConcreteAttribute::PyConcreteAttribute; 822 823 py::float_ dunderGetItem(intptr_t pos) { 824 if (pos < 0 || pos >= dunderLen()) { 825 throw SetPyError(PyExc_IndexError, 826 "attempt to access out of bounds element"); 827 } 828 829 MlirType type = mlirAttributeGetType(*this); 830 type = mlirShapedTypeGetElementType(type); 831 // Dispatch element extraction to an appropriate C function based on the 832 // elemental type of the attribute. py::float_ is implicitly constructible 833 // from float and double. 834 // TODO: consider caching the type properties in the constructor to avoid 835 // querying them on each element access. 836 if (mlirTypeIsAF32(type)) { 837 return mlirDenseElementsAttrGetFloatValue(*this, pos); 838 } 839 if (mlirTypeIsAF64(type)) { 840 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 841 } 842 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 843 } 844 845 static void bindDerived(ClassTy &c) { 846 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 847 } 848 }; 849 850 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 851 public: 852 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 853 static constexpr const char *pyClassName = "TypeAttr"; 854 using PyConcreteAttribute::PyConcreteAttribute; 855 856 static void bindDerived(ClassTy &c) { 857 c.def_static( 858 "get", 859 [](PyType value, DefaultingPyMlirContext context) { 860 MlirAttribute attr = mlirTypeAttrGet(value.get()); 861 return PyTypeAttribute(context->getRef(), attr); 862 }, 863 py::arg("value"), py::arg("context") = py::none(), 864 "Gets a uniqued Type attribute"); 865 c.def_property_readonly("value", [](PyTypeAttribute &self) { 866 return PyType(self.getContext()->getRef(), 867 mlirTypeAttrGetValue(self.get())); 868 }); 869 } 870 }; 871 872 /// Unit Attribute subclass. Unit attributes don't have values. 873 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 874 public: 875 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 876 static constexpr const char *pyClassName = "UnitAttr"; 877 using PyConcreteAttribute::PyConcreteAttribute; 878 879 static void bindDerived(ClassTy &c) { 880 c.def_static( 881 "get", 882 [](DefaultingPyMlirContext context) { 883 return PyUnitAttribute(context->getRef(), 884 mlirUnitAttrGet(context->get())); 885 }, 886 py::arg("context") = py::none(), "Create a Unit attribute."); 887 } 888 }; 889 890 } // namespace 891 892 void mlir::python::populateIRAttributes(py::module &m) { 893 PyAffineMapAttribute::bind(m); 894 PyArrayAttribute::bind(m); 895 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 896 PyBoolAttribute::bind(m); 897 PyDenseElementsAttribute::bind(m); 898 PyDenseFPElementsAttribute::bind(m); 899 PyDenseIntElementsAttribute::bind(m); 900 PyDictAttribute::bind(m); 901 PyFlatSymbolRefAttribute::bind(m); 902 PyOpaqueAttribute::bind(m); 903 PyFloatAttribute::bind(m); 904 PyIntegerAttribute::bind(m); 905 PyStringAttribute::bind(m); 906 PyTypeAttribute::bind(m); 907 PyUnitAttribute::bind(m); 908 } 909