1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::Optional; 21 using llvm::SmallVector; 22 using llvm::Twine; 23 24 //------------------------------------------------------------------------------ 25 // Docstrings (trivial, non-duplicated docstrings are included inline). 26 //------------------------------------------------------------------------------ 27 28 static const char kDenseElementsAttrGetDocstring[] = 29 R"(Gets a DenseElementsAttr from a Python buffer or array. 30 31 When `type` is not provided, then some limited type inferencing is done based 32 on the buffer format. Support presently exists for 8/16/32/64 signed and 33 unsigned integers and float16/float32/float64. DenseElementsAttrs of these 34 types can also be converted back to a corresponding buffer. 35 36 For conversions outside of these types, a `type=` must be explicitly provided 37 and the buffer contents must be bit-castable to the MLIR internal 38 representation: 39 40 * Integer types (except for i1): the buffer must be byte aligned to the 41 next byte boundary. 42 * Floating point types: Must be bit-castable to the given floating point 43 size. 44 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 45 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 46 this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 47 48 If a single element buffer is passed (or for i1, a single byte with value 0 49 or 255), then a splat will be created. 50 51 Args: 52 array: The array or buffer to convert. 53 signless: If inferring an appropriate MLIR type, use signless types for 54 integers (defaults True). 55 type: Skips inference of the MLIR element type and uses this instead. The 56 storage size must be consistent with the actual contents of the buffer. 57 shape: Overrides the shape of the buffer when constructing the MLIR 58 shaped type. This is needed when the physical and logical shape differ (as 59 for i1). 60 context: Explicit context, if not from context manager. 61 62 Returns: 63 DenseElementsAttr on success. 64 65 Raises: 66 ValueError: If the type of the buffer or array cannot be matched to an MLIR 67 type or if the buffer does not meet expectations. 68 )"; 69 70 namespace { 71 72 static MlirStringRef toMlirStringRef(const std::string &s) { 73 return mlirStringRefCreate(s.data(), s.size()); 74 } 75 76 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 77 public: 78 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 79 static constexpr const char *pyClassName = "AffineMapAttr"; 80 using PyConcreteAttribute::PyConcreteAttribute; 81 82 static void bindDerived(ClassTy &c) { 83 c.def_static( 84 "get", 85 [](PyAffineMap &affineMap) { 86 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 87 return PyAffineMapAttribute(affineMap.getContext(), attr); 88 }, 89 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 90 } 91 }; 92 93 template <typename T> 94 static T pyTryCast(py::handle object) { 95 try { 96 return object.cast<T>(); 97 } catch (py::cast_error &err) { 98 std::string msg = 99 std::string( 100 "Invalid attribute when attempting to create an ArrayAttribute (") + 101 err.what() + ")"; 102 throw py::cast_error(msg); 103 } catch (py::reference_cast_error &err) { 104 std::string msg = std::string("Invalid attribute (None?) when attempting " 105 "to create an ArrayAttribute (") + 106 err.what() + ")"; 107 throw py::cast_error(msg); 108 } 109 } 110 111 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 112 public: 113 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 114 static constexpr const char *pyClassName = "ArrayAttr"; 115 using PyConcreteAttribute::PyConcreteAttribute; 116 117 class PyArrayAttributeIterator { 118 public: 119 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 120 121 PyArrayAttributeIterator &dunderIter() { return *this; } 122 123 PyAttribute dunderNext() { 124 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 125 throw py::stop_iteration(); 126 } 127 return PyAttribute(attr.getContext(), 128 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 129 } 130 131 static void bind(py::module &m) { 132 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 133 py::module_local()) 134 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 135 .def("__next__", &PyArrayAttributeIterator::dunderNext); 136 } 137 138 private: 139 PyAttribute attr; 140 int nextIndex = 0; 141 }; 142 143 PyAttribute getItem(intptr_t i) { 144 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 145 } 146 147 static void bindDerived(ClassTy &c) { 148 c.def_static( 149 "get", 150 [](py::list attributes, DefaultingPyMlirContext context) { 151 SmallVector<MlirAttribute> mlirAttributes; 152 mlirAttributes.reserve(py::len(attributes)); 153 for (auto attribute : attributes) { 154 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 155 } 156 MlirAttribute attr = mlirArrayAttrGet( 157 context->get(), mlirAttributes.size(), mlirAttributes.data()); 158 return PyArrayAttribute(context->getRef(), attr); 159 }, 160 py::arg("attributes"), py::arg("context") = py::none(), 161 "Gets a uniqued Array attribute"); 162 c.def("__getitem__", 163 [](PyArrayAttribute &arr, intptr_t i) { 164 if (i >= mlirArrayAttrGetNumElements(arr)) 165 throw py::index_error("ArrayAttribute index out of range"); 166 return arr.getItem(i); 167 }) 168 .def("__len__", 169 [](const PyArrayAttribute &arr) { 170 return mlirArrayAttrGetNumElements(arr); 171 }) 172 .def("__iter__", [](const PyArrayAttribute &arr) { 173 return PyArrayAttributeIterator(arr); 174 }); 175 c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 176 std::vector<MlirAttribute> attributes; 177 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 178 attributes.reserve(numOldElements + py::len(extras)); 179 for (intptr_t i = 0; i < numOldElements; ++i) 180 attributes.push_back(arr.getItem(i)); 181 for (py::handle attr : extras) 182 attributes.push_back(pyTryCast<PyAttribute>(attr)); 183 MlirAttribute arrayAttr = mlirArrayAttrGet( 184 arr.getContext()->get(), attributes.size(), attributes.data()); 185 return PyArrayAttribute(arr.getContext(), arrayAttr); 186 }); 187 } 188 }; 189 190 /// Float Point Attribute subclass - FloatAttr. 191 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 192 public: 193 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 194 static constexpr const char *pyClassName = "FloatAttr"; 195 using PyConcreteAttribute::PyConcreteAttribute; 196 197 static void bindDerived(ClassTy &c) { 198 c.def_static( 199 "get", 200 [](PyType &type, double value, DefaultingPyLocation loc) { 201 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 202 // TODO: Rework error reporting once diagnostic engine is exposed 203 // in C API. 204 if (mlirAttributeIsNull(attr)) { 205 throw SetPyError(PyExc_ValueError, 206 Twine("invalid '") + 207 py::repr(py::cast(type)).cast<std::string>() + 208 "' and expected floating point type."); 209 } 210 return PyFloatAttribute(type.getContext(), attr); 211 }, 212 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 213 "Gets an uniqued float point attribute associated to a type"); 214 c.def_static( 215 "get_f32", 216 [](double value, DefaultingPyMlirContext context) { 217 MlirAttribute attr = mlirFloatAttrDoubleGet( 218 context->get(), mlirF32TypeGet(context->get()), value); 219 return PyFloatAttribute(context->getRef(), attr); 220 }, 221 py::arg("value"), py::arg("context") = py::none(), 222 "Gets an uniqued float point attribute associated to a f32 type"); 223 c.def_static( 224 "get_f64", 225 [](double value, DefaultingPyMlirContext context) { 226 MlirAttribute attr = mlirFloatAttrDoubleGet( 227 context->get(), mlirF64TypeGet(context->get()), value); 228 return PyFloatAttribute(context->getRef(), attr); 229 }, 230 py::arg("value"), py::arg("context") = py::none(), 231 "Gets an uniqued float point attribute associated to a f64 type"); 232 c.def_property_readonly( 233 "value", 234 [](PyFloatAttribute &self) { 235 return mlirFloatAttrGetValueDouble(self); 236 }, 237 "Returns the value of the float point attribute"); 238 } 239 }; 240 241 /// Integer Attribute subclass - IntegerAttr. 242 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 243 public: 244 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 245 static constexpr const char *pyClassName = "IntegerAttr"; 246 using PyConcreteAttribute::PyConcreteAttribute; 247 248 static void bindDerived(ClassTy &c) { 249 c.def_static( 250 "get", 251 [](PyType &type, int64_t value) { 252 MlirAttribute attr = mlirIntegerAttrGet(type, value); 253 return PyIntegerAttribute(type.getContext(), attr); 254 }, 255 py::arg("type"), py::arg("value"), 256 "Gets an uniqued integer attribute associated to a type"); 257 c.def_property_readonly( 258 "value", 259 [](PyIntegerAttribute &self) { 260 return mlirIntegerAttrGetValueInt(self); 261 }, 262 "Returns the value of the integer attribute"); 263 } 264 }; 265 266 /// Bool Attribute subclass - BoolAttr. 267 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 268 public: 269 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 270 static constexpr const char *pyClassName = "BoolAttr"; 271 using PyConcreteAttribute::PyConcreteAttribute; 272 273 static void bindDerived(ClassTy &c) { 274 c.def_static( 275 "get", 276 [](bool value, DefaultingPyMlirContext context) { 277 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 278 return PyBoolAttribute(context->getRef(), attr); 279 }, 280 py::arg("value"), py::arg("context") = py::none(), 281 "Gets an uniqued bool attribute"); 282 c.def_property_readonly( 283 "value", 284 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 285 "Returns the value of the bool attribute"); 286 } 287 }; 288 289 class PyFlatSymbolRefAttribute 290 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 291 public: 292 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 293 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 294 using PyConcreteAttribute::PyConcreteAttribute; 295 296 static void bindDerived(ClassTy &c) { 297 c.def_static( 298 "get", 299 [](std::string value, DefaultingPyMlirContext context) { 300 MlirAttribute attr = 301 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 302 return PyFlatSymbolRefAttribute(context->getRef(), attr); 303 }, 304 py::arg("value"), py::arg("context") = py::none(), 305 "Gets a uniqued FlatSymbolRef attribute"); 306 c.def_property_readonly( 307 "value", 308 [](PyFlatSymbolRefAttribute &self) { 309 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 310 return py::str(stringRef.data, stringRef.length); 311 }, 312 "Returns the value of the FlatSymbolRef attribute as a string"); 313 } 314 }; 315 316 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 317 public: 318 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 319 static constexpr const char *pyClassName = "StringAttr"; 320 using PyConcreteAttribute::PyConcreteAttribute; 321 322 static void bindDerived(ClassTy &c) { 323 c.def_static( 324 "get", 325 [](std::string value, DefaultingPyMlirContext context) { 326 MlirAttribute attr = 327 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 328 return PyStringAttribute(context->getRef(), attr); 329 }, 330 py::arg("value"), py::arg("context") = py::none(), 331 "Gets a uniqued string attribute"); 332 c.def_static( 333 "get_typed", 334 [](PyType &type, std::string value) { 335 MlirAttribute attr = 336 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 337 return PyStringAttribute(type.getContext(), attr); 338 }, 339 py::arg("type"), py::arg("value"), 340 "Gets a uniqued string attribute associated to a type"); 341 c.def_property_readonly( 342 "value", 343 [](PyStringAttribute &self) { 344 MlirStringRef stringRef = mlirStringAttrGetValue(self); 345 return py::str(stringRef.data, stringRef.length); 346 }, 347 "Returns the value of the string attribute"); 348 } 349 }; 350 351 // TODO: Support construction of string elements. 352 class PyDenseElementsAttribute 353 : public PyConcreteAttribute<PyDenseElementsAttribute> { 354 public: 355 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 356 static constexpr const char *pyClassName = "DenseElementsAttr"; 357 using PyConcreteAttribute::PyConcreteAttribute; 358 359 static PyDenseElementsAttribute 360 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 361 Optional<std::vector<int64_t>> explicitShape, 362 DefaultingPyMlirContext contextWrapper) { 363 // Request a contiguous view. In exotic cases, this will cause a copy. 364 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 365 Py_buffer *view = new Py_buffer(); 366 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 367 delete view; 368 throw py::error_already_set(); 369 } 370 py::buffer_info arrayInfo(view); 371 SmallVector<int64_t> shape; 372 if (explicitShape) { 373 shape.append(explicitShape->begin(), explicitShape->end()); 374 } else { 375 shape.append(arrayInfo.shape.begin(), 376 arrayInfo.shape.begin() + arrayInfo.ndim); 377 } 378 379 MlirAttribute encodingAttr = mlirAttributeGetNull(); 380 MlirContext context = contextWrapper->get(); 381 382 // Detect format codes that are suitable for bulk loading. This includes 383 // all byte aligned integer and floating point types up to 8 bytes. 384 // Notably, this excludes, bool (which needs to be bit-packed) and 385 // other exotics which do not have a direct representation in the buffer 386 // protocol (i.e. complex, etc). 387 Optional<MlirType> bulkLoadElementType; 388 if (explicitType) { 389 bulkLoadElementType = *explicitType; 390 } else if (arrayInfo.format == "f") { 391 // f32 392 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 393 bulkLoadElementType = mlirF32TypeGet(context); 394 } else if (arrayInfo.format == "d") { 395 // f64 396 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 397 bulkLoadElementType = mlirF64TypeGet(context); 398 } else if (arrayInfo.format == "e") { 399 // f16 400 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 401 bulkLoadElementType = mlirF16TypeGet(context); 402 } else if (isSignedIntegerFormat(arrayInfo.format)) { 403 if (arrayInfo.itemsize == 4) { 404 // i32 405 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 406 : mlirIntegerTypeSignedGet(context, 32); 407 } else if (arrayInfo.itemsize == 8) { 408 // i64 409 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 410 : mlirIntegerTypeSignedGet(context, 64); 411 } else if (arrayInfo.itemsize == 1) { 412 // i8 413 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 414 : mlirIntegerTypeSignedGet(context, 8); 415 } else if (arrayInfo.itemsize == 2) { 416 // i16 417 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 418 : mlirIntegerTypeSignedGet(context, 16); 419 } 420 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 421 if (arrayInfo.itemsize == 4) { 422 // unsigned i32 423 bulkLoadElementType = signless 424 ? mlirIntegerTypeGet(context, 32) 425 : mlirIntegerTypeUnsignedGet(context, 32); 426 } else if (arrayInfo.itemsize == 8) { 427 // unsigned i64 428 bulkLoadElementType = signless 429 ? mlirIntegerTypeGet(context, 64) 430 : mlirIntegerTypeUnsignedGet(context, 64); 431 } else if (arrayInfo.itemsize == 1) { 432 // i8 433 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 434 : mlirIntegerTypeUnsignedGet(context, 8); 435 } else if (arrayInfo.itemsize == 2) { 436 // i16 437 bulkLoadElementType = signless 438 ? mlirIntegerTypeGet(context, 16) 439 : mlirIntegerTypeUnsignedGet(context, 16); 440 } 441 } 442 if (bulkLoadElementType) { 443 auto shapedType = mlirRankedTensorTypeGet( 444 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 445 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 446 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 447 shapedType, rawBufferSize, arrayInfo.ptr); 448 if (mlirAttributeIsNull(attr)) { 449 throw std::invalid_argument( 450 "DenseElementsAttr could not be constructed from the given buffer. " 451 "This may mean that the Python buffer layout does not match that " 452 "MLIR expected layout and is a bug."); 453 } 454 return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 455 } 456 457 throw std::invalid_argument( 458 std::string("unimplemented array format conversion from format: ") + 459 arrayInfo.format); 460 } 461 462 static PyDenseElementsAttribute getSplat(PyType shapedType, 463 PyAttribute &elementAttr) { 464 auto contextWrapper = 465 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 466 if (!mlirAttributeIsAInteger(elementAttr) && 467 !mlirAttributeIsAFloat(elementAttr)) { 468 std::string message = "Illegal element type for DenseElementsAttr: "; 469 message.append(py::repr(py::cast(elementAttr))); 470 throw SetPyError(PyExc_ValueError, message); 471 } 472 if (!mlirTypeIsAShaped(shapedType) || 473 !mlirShapedTypeHasStaticShape(shapedType)) { 474 std::string message = 475 "Expected a static ShapedType for the shaped_type parameter: "; 476 message.append(py::repr(py::cast(shapedType))); 477 throw SetPyError(PyExc_ValueError, message); 478 } 479 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 480 MlirType attrType = mlirAttributeGetType(elementAttr); 481 if (!mlirTypeEqual(shapedElementType, attrType)) { 482 std::string message = 483 "Shaped element type and attribute type must be equal: shaped="; 484 message.append(py::repr(py::cast(shapedType))); 485 message.append(", element="); 486 message.append(py::repr(py::cast(elementAttr))); 487 throw SetPyError(PyExc_ValueError, message); 488 } 489 490 MlirAttribute elements = 491 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 492 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 493 } 494 495 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 496 497 py::buffer_info accessBuffer() { 498 if (mlirDenseElementsAttrIsSplat(*this)) { 499 // TODO: Currently crashes the program. 500 // Reported as https://github.com/pybind/pybind11/issues/3336 501 throw std::invalid_argument( 502 "unsupported data type for conversion to Python buffer"); 503 } 504 505 MlirType shapedType = mlirAttributeGetType(*this); 506 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 507 std::string format; 508 509 if (mlirTypeIsAF32(elementType)) { 510 // f32 511 return bufferInfo<float>(shapedType); 512 } 513 if (mlirTypeIsAF64(elementType)) { 514 // f64 515 return bufferInfo<double>(shapedType); 516 } 517 if (mlirTypeIsAF16(elementType)) { 518 // f16 519 return bufferInfo<uint16_t>(shapedType, "e"); 520 } 521 if (mlirTypeIsAInteger(elementType) && 522 mlirIntegerTypeGetWidth(elementType) == 32) { 523 if (mlirIntegerTypeIsSignless(elementType) || 524 mlirIntegerTypeIsSigned(elementType)) { 525 // i32 526 return bufferInfo<int32_t>(shapedType); 527 } 528 if (mlirIntegerTypeIsUnsigned(elementType)) { 529 // unsigned i32 530 return bufferInfo<uint32_t>(shapedType); 531 } 532 } else if (mlirTypeIsAInteger(elementType) && 533 mlirIntegerTypeGetWidth(elementType) == 64) { 534 if (mlirIntegerTypeIsSignless(elementType) || 535 mlirIntegerTypeIsSigned(elementType)) { 536 // i64 537 return bufferInfo<int64_t>(shapedType); 538 } 539 if (mlirIntegerTypeIsUnsigned(elementType)) { 540 // unsigned i64 541 return bufferInfo<uint64_t>(shapedType); 542 } 543 } else if (mlirTypeIsAInteger(elementType) && 544 mlirIntegerTypeGetWidth(elementType) == 8) { 545 if (mlirIntegerTypeIsSignless(elementType) || 546 mlirIntegerTypeIsSigned(elementType)) { 547 // i8 548 return bufferInfo<int8_t>(shapedType); 549 } 550 if (mlirIntegerTypeIsUnsigned(elementType)) { 551 // unsigned i8 552 return bufferInfo<uint8_t>(shapedType); 553 } 554 } else if (mlirTypeIsAInteger(elementType) && 555 mlirIntegerTypeGetWidth(elementType) == 16) { 556 if (mlirIntegerTypeIsSignless(elementType) || 557 mlirIntegerTypeIsSigned(elementType)) { 558 // i16 559 return bufferInfo<int16_t>(shapedType); 560 } 561 if (mlirIntegerTypeIsUnsigned(elementType)) { 562 // unsigned i16 563 return bufferInfo<uint16_t>(shapedType); 564 } 565 } 566 567 // TODO: Currently crashes the program. 568 // Reported as https://github.com/pybind/pybind11/issues/3336 569 throw std::invalid_argument( 570 "unsupported data type for conversion to Python buffer"); 571 } 572 573 static void bindDerived(ClassTy &c) { 574 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 575 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 576 py::arg("array"), py::arg("signless") = true, 577 py::arg("type") = py::none(), py::arg("shape") = py::none(), 578 py::arg("context") = py::none(), 579 kDenseElementsAttrGetDocstring) 580 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 581 py::arg("shaped_type"), py::arg("element_attr"), 582 "Gets a DenseElementsAttr where all values are the same") 583 .def_property_readonly("is_splat", 584 [](PyDenseElementsAttribute &self) -> bool { 585 return mlirDenseElementsAttrIsSplat(self); 586 }) 587 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 588 } 589 590 private: 591 static bool isUnsignedIntegerFormat(const std::string &format) { 592 if (format.empty()) 593 return false; 594 char code = format[0]; 595 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 596 code == 'Q'; 597 } 598 599 static bool isSignedIntegerFormat(const std::string &format) { 600 if (format.empty()) 601 return false; 602 char code = format[0]; 603 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 604 code == 'q'; 605 } 606 607 template <typename Type> 608 py::buffer_info bufferInfo(MlirType shapedType, 609 const char *explicitFormat = nullptr) { 610 intptr_t rank = mlirShapedTypeGetRank(shapedType); 611 // Prepare the data for the buffer_info. 612 // Buffer is configured for read-only access below. 613 Type *data = static_cast<Type *>( 614 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 615 // Prepare the shape for the buffer_info. 616 SmallVector<intptr_t, 4> shape; 617 for (intptr_t i = 0; i < rank; ++i) 618 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 619 // Prepare the strides for the buffer_info. 620 SmallVector<intptr_t, 4> strides; 621 intptr_t strideFactor = 1; 622 for (intptr_t i = 1; i < rank; ++i) { 623 strideFactor = 1; 624 for (intptr_t j = i; j < rank; ++j) { 625 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 626 } 627 strides.push_back(sizeof(Type) * strideFactor); 628 } 629 strides.push_back(sizeof(Type)); 630 std::string format; 631 if (explicitFormat) { 632 format = explicitFormat; 633 } else { 634 format = py::format_descriptor<Type>::format(); 635 } 636 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 637 /*readonly=*/true); 638 } 639 }; // namespace 640 641 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 642 /// (and boolean) values. Supports element access. 643 class PyDenseIntElementsAttribute 644 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 645 PyDenseElementsAttribute> { 646 public: 647 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 648 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 649 using PyConcreteAttribute::PyConcreteAttribute; 650 651 /// Returns the element at the given linear position. Asserts if the index is 652 /// out of range. 653 py::int_ dunderGetItem(intptr_t pos) { 654 if (pos < 0 || pos >= dunderLen()) { 655 throw SetPyError(PyExc_IndexError, 656 "attempt to access out of bounds element"); 657 } 658 659 MlirType type = mlirAttributeGetType(*this); 660 type = mlirShapedTypeGetElementType(type); 661 assert(mlirTypeIsAInteger(type) && 662 "expected integer element type in dense int elements attribute"); 663 // Dispatch element extraction to an appropriate C function based on the 664 // elemental type of the attribute. py::int_ is implicitly constructible 665 // from any C++ integral type and handles bitwidth correctly. 666 // TODO: consider caching the type properties in the constructor to avoid 667 // querying them on each element access. 668 unsigned width = mlirIntegerTypeGetWidth(type); 669 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 670 if (isUnsigned) { 671 if (width == 1) { 672 return mlirDenseElementsAttrGetBoolValue(*this, pos); 673 } 674 if (width == 32) { 675 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 676 } 677 if (width == 64) { 678 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 679 } 680 } else { 681 if (width == 1) { 682 return mlirDenseElementsAttrGetBoolValue(*this, pos); 683 } 684 if (width == 32) { 685 return mlirDenseElementsAttrGetInt32Value(*this, pos); 686 } 687 if (width == 64) { 688 return mlirDenseElementsAttrGetInt64Value(*this, pos); 689 } 690 } 691 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 692 } 693 694 static void bindDerived(ClassTy &c) { 695 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 696 } 697 }; 698 699 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 700 public: 701 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 702 static constexpr const char *pyClassName = "DictAttr"; 703 using PyConcreteAttribute::PyConcreteAttribute; 704 705 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 706 707 bool dunderContains(const std::string &name) { 708 return !mlirAttributeIsNull( 709 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 710 } 711 712 static void bindDerived(ClassTy &c) { 713 c.def("__contains__", &PyDictAttribute::dunderContains); 714 c.def("__len__", &PyDictAttribute::dunderLen); 715 c.def_static( 716 "get", 717 [](py::dict attributes, DefaultingPyMlirContext context) { 718 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 719 mlirNamedAttributes.reserve(attributes.size()); 720 for (auto &it : attributes) { 721 auto &mlirAttr = it.second.cast<PyAttribute &>(); 722 auto name = it.first.cast<std::string>(); 723 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 724 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 725 toMlirStringRef(name)), 726 mlirAttr)); 727 } 728 MlirAttribute attr = 729 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 730 mlirNamedAttributes.data()); 731 return PyDictAttribute(context->getRef(), attr); 732 }, 733 py::arg("value") = py::dict(), py::arg("context") = py::none(), 734 "Gets an uniqued dict attribute"); 735 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 736 MlirAttribute attr = 737 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 738 if (mlirAttributeIsNull(attr)) { 739 throw SetPyError(PyExc_KeyError, 740 "attempt to access a non-existent attribute"); 741 } 742 return PyAttribute(self.getContext(), attr); 743 }); 744 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 745 if (index < 0 || index >= self.dunderLen()) { 746 throw SetPyError(PyExc_IndexError, 747 "attempt to access out of bounds attribute"); 748 } 749 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 750 return PyNamedAttribute( 751 namedAttr.attribute, 752 std::string(mlirIdentifierStr(namedAttr.name).data)); 753 }); 754 } 755 }; 756 757 /// Refinement of PyDenseElementsAttribute for attributes containing 758 /// floating-point values. Supports element access. 759 class PyDenseFPElementsAttribute 760 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 761 PyDenseElementsAttribute> { 762 public: 763 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 764 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 765 using PyConcreteAttribute::PyConcreteAttribute; 766 767 py::float_ dunderGetItem(intptr_t pos) { 768 if (pos < 0 || pos >= dunderLen()) { 769 throw SetPyError(PyExc_IndexError, 770 "attempt to access out of bounds element"); 771 } 772 773 MlirType type = mlirAttributeGetType(*this); 774 type = mlirShapedTypeGetElementType(type); 775 // Dispatch element extraction to an appropriate C function based on the 776 // elemental type of the attribute. py::float_ is implicitly constructible 777 // from float and double. 778 // TODO: consider caching the type properties in the constructor to avoid 779 // querying them on each element access. 780 if (mlirTypeIsAF32(type)) { 781 return mlirDenseElementsAttrGetFloatValue(*this, pos); 782 } 783 if (mlirTypeIsAF64(type)) { 784 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 785 } 786 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 787 } 788 789 static void bindDerived(ClassTy &c) { 790 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 791 } 792 }; 793 794 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 795 public: 796 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 797 static constexpr const char *pyClassName = "TypeAttr"; 798 using PyConcreteAttribute::PyConcreteAttribute; 799 800 static void bindDerived(ClassTy &c) { 801 c.def_static( 802 "get", 803 [](PyType value, DefaultingPyMlirContext context) { 804 MlirAttribute attr = mlirTypeAttrGet(value.get()); 805 return PyTypeAttribute(context->getRef(), attr); 806 }, 807 py::arg("value"), py::arg("context") = py::none(), 808 "Gets a uniqued Type attribute"); 809 c.def_property_readonly("value", [](PyTypeAttribute &self) { 810 return PyType(self.getContext()->getRef(), 811 mlirTypeAttrGetValue(self.get())); 812 }); 813 } 814 }; 815 816 /// Unit Attribute subclass. Unit attributes don't have values. 817 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 818 public: 819 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 820 static constexpr const char *pyClassName = "UnitAttr"; 821 using PyConcreteAttribute::PyConcreteAttribute; 822 823 static void bindDerived(ClassTy &c) { 824 c.def_static( 825 "get", 826 [](DefaultingPyMlirContext context) { 827 return PyUnitAttribute(context->getRef(), 828 mlirUnitAttrGet(context->get())); 829 }, 830 py::arg("context") = py::none(), "Create a Unit attribute."); 831 } 832 }; 833 834 } // namespace 835 836 void mlir::python::populateIRAttributes(py::module &m) { 837 PyAffineMapAttribute::bind(m); 838 PyArrayAttribute::bind(m); 839 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 840 PyBoolAttribute::bind(m); 841 PyDenseElementsAttribute::bind(m); 842 PyDenseFPElementsAttribute::bind(m); 843 PyDenseIntElementsAttribute::bind(m); 844 PyDictAttribute::bind(m); 845 PyFlatSymbolRefAttribute::bind(m); 846 PyFloatAttribute::bind(m); 847 PyIntegerAttribute::bind(m); 848 PyStringAttribute::bind(m); 849 PyTypeAttribute::bind(m); 850 PyUnitAttribute::bind(m); 851 } 852