1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/BuiltinTypes.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 using llvm::None; 21 using llvm::Optional; 22 using llvm::SmallVector; 23 using llvm::Twine; 24 25 //------------------------------------------------------------------------------ 26 // Docstrings (trivial, non-duplicated docstrings are included inline). 27 //------------------------------------------------------------------------------ 28 29 static const char kDenseElementsAttrGetDocstring[] = 30 R"(Gets a DenseElementsAttr from a Python buffer or array. 31 32 When `type` is not provided, then some limited type inferencing is done based 33 on the buffer format. Support presently exists for 8/16/32/64 signed and 34 unsigned integers and float16/float32/float64. DenseElementsAttrs of these 35 types can also be converted back to a corresponding buffer. 36 37 For conversions outside of these types, a `type=` must be explicitly provided 38 and the buffer contents must be bit-castable to the MLIR internal 39 representation: 40 41 * Integer types (except for i1): the buffer must be byte aligned to the 42 next byte boundary. 43 * Floating point types: Must be bit-castable to the given floating point 44 size. 45 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 46 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 47 this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 48 49 If a single element buffer is passed (or for i1, a single byte with value 0 50 or 255), then a splat will be created. 51 52 Args: 53 array: The array or buffer to convert. 54 signless: If inferring an appropriate MLIR type, use signless types for 55 integers (defaults True). 56 type: Skips inference of the MLIR element type and uses this instead. The 57 storage size must be consistent with the actual contents of the buffer. 58 shape: Overrides the shape of the buffer when constructing the MLIR 59 shaped type. This is needed when the physical and logical shape differ (as 60 for i1). 61 context: Explicit context, if not from context manager. 62 63 Returns: 64 DenseElementsAttr on success. 65 66 Raises: 67 ValueError: If the type of the buffer or array cannot be matched to an MLIR 68 type or if the buffer does not meet expectations. 69 )"; 70 71 namespace { 72 73 static MlirStringRef toMlirStringRef(const std::string &s) { 74 return mlirStringRefCreate(s.data(), s.size()); 75 } 76 77 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 78 public: 79 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 80 static constexpr const char *pyClassName = "AffineMapAttr"; 81 using PyConcreteAttribute::PyConcreteAttribute; 82 83 static void bindDerived(ClassTy &c) { 84 c.def_static( 85 "get", 86 [](PyAffineMap &affineMap) { 87 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 88 return PyAffineMapAttribute(affineMap.getContext(), attr); 89 }, 90 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 91 } 92 }; 93 94 template <typename T> 95 static T pyTryCast(py::handle object) { 96 try { 97 return object.cast<T>(); 98 } catch (py::cast_error &err) { 99 std::string msg = 100 std::string( 101 "Invalid attribute when attempting to create an ArrayAttribute (") + 102 err.what() + ")"; 103 throw py::cast_error(msg); 104 } catch (py::reference_cast_error &err) { 105 std::string msg = std::string("Invalid attribute (None?) when attempting " 106 "to create an ArrayAttribute (") + 107 err.what() + ")"; 108 throw py::cast_error(msg); 109 } 110 } 111 112 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 113 public: 114 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 115 static constexpr const char *pyClassName = "ArrayAttr"; 116 using PyConcreteAttribute::PyConcreteAttribute; 117 118 class PyArrayAttributeIterator { 119 public: 120 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 121 122 PyArrayAttributeIterator &dunderIter() { return *this; } 123 124 PyAttribute dunderNext() { 125 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 126 throw py::stop_iteration(); 127 } 128 return PyAttribute(attr.getContext(), 129 mlirArrayAttrGetElement(attr.get(), nextIndex++)); 130 } 131 132 static void bind(py::module &m) { 133 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 134 py::module_local()) 135 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 136 .def("__next__", &PyArrayAttributeIterator::dunderNext); 137 } 138 139 private: 140 PyAttribute attr; 141 int nextIndex = 0; 142 }; 143 144 PyAttribute getItem(intptr_t i) { 145 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 146 } 147 148 static void bindDerived(ClassTy &c) { 149 c.def_static( 150 "get", 151 [](py::list attributes, DefaultingPyMlirContext context) { 152 SmallVector<MlirAttribute> mlirAttributes; 153 mlirAttributes.reserve(py::len(attributes)); 154 for (auto attribute : attributes) { 155 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 156 } 157 MlirAttribute attr = mlirArrayAttrGet( 158 context->get(), mlirAttributes.size(), mlirAttributes.data()); 159 return PyArrayAttribute(context->getRef(), attr); 160 }, 161 py::arg("attributes"), py::arg("context") = py::none(), 162 "Gets a uniqued Array attribute"); 163 c.def("__getitem__", 164 [](PyArrayAttribute &arr, intptr_t i) { 165 if (i >= mlirArrayAttrGetNumElements(arr)) 166 throw py::index_error("ArrayAttribute index out of range"); 167 return arr.getItem(i); 168 }) 169 .def("__len__", 170 [](const PyArrayAttribute &arr) { 171 return mlirArrayAttrGetNumElements(arr); 172 }) 173 .def("__iter__", [](const PyArrayAttribute &arr) { 174 return PyArrayAttributeIterator(arr); 175 }); 176 c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 177 std::vector<MlirAttribute> attributes; 178 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 179 attributes.reserve(numOldElements + py::len(extras)); 180 for (intptr_t i = 0; i < numOldElements; ++i) 181 attributes.push_back(arr.getItem(i)); 182 for (py::handle attr : extras) 183 attributes.push_back(pyTryCast<PyAttribute>(attr)); 184 MlirAttribute arrayAttr = mlirArrayAttrGet( 185 arr.getContext()->get(), attributes.size(), attributes.data()); 186 return PyArrayAttribute(arr.getContext(), arrayAttr); 187 }); 188 } 189 }; 190 191 /// Float Point Attribute subclass - FloatAttr. 192 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 193 public: 194 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 195 static constexpr const char *pyClassName = "FloatAttr"; 196 using PyConcreteAttribute::PyConcreteAttribute; 197 198 static void bindDerived(ClassTy &c) { 199 c.def_static( 200 "get", 201 [](PyType &type, double value, DefaultingPyLocation loc) { 202 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 203 // TODO: Rework error reporting once diagnostic engine is exposed 204 // in C API. 205 if (mlirAttributeIsNull(attr)) { 206 throw SetPyError(PyExc_ValueError, 207 Twine("invalid '") + 208 py::repr(py::cast(type)).cast<std::string>() + 209 "' and expected floating point type."); 210 } 211 return PyFloatAttribute(type.getContext(), attr); 212 }, 213 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 214 "Gets an uniqued float point attribute associated to a type"); 215 c.def_static( 216 "get_f32", 217 [](double value, DefaultingPyMlirContext context) { 218 MlirAttribute attr = mlirFloatAttrDoubleGet( 219 context->get(), mlirF32TypeGet(context->get()), value); 220 return PyFloatAttribute(context->getRef(), attr); 221 }, 222 py::arg("value"), py::arg("context") = py::none(), 223 "Gets an uniqued float point attribute associated to a f32 type"); 224 c.def_static( 225 "get_f64", 226 [](double value, DefaultingPyMlirContext context) { 227 MlirAttribute attr = mlirFloatAttrDoubleGet( 228 context->get(), mlirF64TypeGet(context->get()), value); 229 return PyFloatAttribute(context->getRef(), attr); 230 }, 231 py::arg("value"), py::arg("context") = py::none(), 232 "Gets an uniqued float point attribute associated to a f64 type"); 233 c.def_property_readonly( 234 "value", 235 [](PyFloatAttribute &self) { 236 return mlirFloatAttrGetValueDouble(self); 237 }, 238 "Returns the value of the float point attribute"); 239 } 240 }; 241 242 /// Integer Attribute subclass - IntegerAttr. 243 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 244 public: 245 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 246 static constexpr const char *pyClassName = "IntegerAttr"; 247 using PyConcreteAttribute::PyConcreteAttribute; 248 249 static void bindDerived(ClassTy &c) { 250 c.def_static( 251 "get", 252 [](PyType &type, int64_t value) { 253 MlirAttribute attr = mlirIntegerAttrGet(type, value); 254 return PyIntegerAttribute(type.getContext(), attr); 255 }, 256 py::arg("type"), py::arg("value"), 257 "Gets an uniqued integer attribute associated to a type"); 258 c.def_property_readonly( 259 "value", 260 [](PyIntegerAttribute &self) { 261 return mlirIntegerAttrGetValueInt(self); 262 }, 263 "Returns the value of the integer attribute"); 264 } 265 }; 266 267 /// Bool Attribute subclass - BoolAttr. 268 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 269 public: 270 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 271 static constexpr const char *pyClassName = "BoolAttr"; 272 using PyConcreteAttribute::PyConcreteAttribute; 273 274 static void bindDerived(ClassTy &c) { 275 c.def_static( 276 "get", 277 [](bool value, DefaultingPyMlirContext context) { 278 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 279 return PyBoolAttribute(context->getRef(), attr); 280 }, 281 py::arg("value"), py::arg("context") = py::none(), 282 "Gets an uniqued bool attribute"); 283 c.def_property_readonly( 284 "value", 285 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 286 "Returns the value of the bool attribute"); 287 } 288 }; 289 290 class PyFlatSymbolRefAttribute 291 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 292 public: 293 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 294 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 295 using PyConcreteAttribute::PyConcreteAttribute; 296 297 static void bindDerived(ClassTy &c) { 298 c.def_static( 299 "get", 300 [](std::string value, DefaultingPyMlirContext context) { 301 MlirAttribute attr = 302 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 303 return PyFlatSymbolRefAttribute(context->getRef(), attr); 304 }, 305 py::arg("value"), py::arg("context") = py::none(), 306 "Gets a uniqued FlatSymbolRef attribute"); 307 c.def_property_readonly( 308 "value", 309 [](PyFlatSymbolRefAttribute &self) { 310 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 311 return py::str(stringRef.data, stringRef.length); 312 }, 313 "Returns the value of the FlatSymbolRef attribute as a string"); 314 } 315 }; 316 317 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 318 public: 319 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 320 static constexpr const char *pyClassName = "StringAttr"; 321 using PyConcreteAttribute::PyConcreteAttribute; 322 323 static void bindDerived(ClassTy &c) { 324 c.def_static( 325 "get", 326 [](std::string value, DefaultingPyMlirContext context) { 327 MlirAttribute attr = 328 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 329 return PyStringAttribute(context->getRef(), attr); 330 }, 331 py::arg("value"), py::arg("context") = py::none(), 332 "Gets a uniqued string attribute"); 333 c.def_static( 334 "get_typed", 335 [](PyType &type, std::string value) { 336 MlirAttribute attr = 337 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 338 return PyStringAttribute(type.getContext(), attr); 339 }, 340 py::arg("type"), py::arg("value"), 341 "Gets a uniqued string attribute associated to a type"); 342 c.def_property_readonly( 343 "value", 344 [](PyStringAttribute &self) { 345 MlirStringRef stringRef = mlirStringAttrGetValue(self); 346 return py::str(stringRef.data, stringRef.length); 347 }, 348 "Returns the value of the string attribute"); 349 } 350 }; 351 352 // TODO: Support construction of string elements. 353 class PyDenseElementsAttribute 354 : public PyConcreteAttribute<PyDenseElementsAttribute> { 355 public: 356 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 357 static constexpr const char *pyClassName = "DenseElementsAttr"; 358 using PyConcreteAttribute::PyConcreteAttribute; 359 360 static PyDenseElementsAttribute 361 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 362 Optional<std::vector<int64_t>> explicitShape, 363 DefaultingPyMlirContext contextWrapper) { 364 // Request a contiguous view. In exotic cases, this will cause a copy. 365 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 366 Py_buffer *view = new Py_buffer(); 367 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 368 delete view; 369 throw py::error_already_set(); 370 } 371 py::buffer_info arrayInfo(view); 372 SmallVector<int64_t> shape; 373 if (explicitShape) { 374 shape.append(explicitShape->begin(), explicitShape->end()); 375 } else { 376 shape.append(arrayInfo.shape.begin(), 377 arrayInfo.shape.begin() + arrayInfo.ndim); 378 } 379 380 MlirAttribute encodingAttr = mlirAttributeGetNull(); 381 MlirContext context = contextWrapper->get(); 382 383 // Detect format codes that are suitable for bulk loading. This includes 384 // all byte aligned integer and floating point types up to 8 bytes. 385 // Notably, this excludes, bool (which needs to be bit-packed) and 386 // other exotics which do not have a direct representation in the buffer 387 // protocol (i.e. complex, etc). 388 Optional<MlirType> bulkLoadElementType; 389 if (explicitType) { 390 bulkLoadElementType = *explicitType; 391 } else if (arrayInfo.format == "f") { 392 // f32 393 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 394 bulkLoadElementType = mlirF32TypeGet(context); 395 } else if (arrayInfo.format == "d") { 396 // f64 397 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 398 bulkLoadElementType = mlirF64TypeGet(context); 399 } else if (arrayInfo.format == "e") { 400 // f16 401 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 402 bulkLoadElementType = mlirF16TypeGet(context); 403 } else if (isSignedIntegerFormat(arrayInfo.format)) { 404 if (arrayInfo.itemsize == 4) { 405 // i32 406 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 407 : mlirIntegerTypeSignedGet(context, 32); 408 } else if (arrayInfo.itemsize == 8) { 409 // i64 410 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 411 : mlirIntegerTypeSignedGet(context, 64); 412 } else if (arrayInfo.itemsize == 1) { 413 // i8 414 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 415 : mlirIntegerTypeSignedGet(context, 8); 416 } else if (arrayInfo.itemsize == 2) { 417 // i16 418 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 419 : mlirIntegerTypeSignedGet(context, 16); 420 } 421 } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 422 if (arrayInfo.itemsize == 4) { 423 // unsigned i32 424 bulkLoadElementType = signless 425 ? mlirIntegerTypeGet(context, 32) 426 : mlirIntegerTypeUnsignedGet(context, 32); 427 } else if (arrayInfo.itemsize == 8) { 428 // unsigned i64 429 bulkLoadElementType = signless 430 ? mlirIntegerTypeGet(context, 64) 431 : mlirIntegerTypeUnsignedGet(context, 64); 432 } else if (arrayInfo.itemsize == 1) { 433 // i8 434 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 435 : mlirIntegerTypeUnsignedGet(context, 8); 436 } else if (arrayInfo.itemsize == 2) { 437 // i16 438 bulkLoadElementType = signless 439 ? mlirIntegerTypeGet(context, 16) 440 : mlirIntegerTypeUnsignedGet(context, 16); 441 } 442 } 443 if (bulkLoadElementType) { 444 auto shapedType = mlirRankedTensorTypeGet( 445 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 446 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 447 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 448 shapedType, rawBufferSize, arrayInfo.ptr); 449 if (mlirAttributeIsNull(attr)) { 450 throw std::invalid_argument( 451 "DenseElementsAttr could not be constructed from the given buffer. " 452 "This may mean that the Python buffer layout does not match that " 453 "MLIR expected layout and is a bug."); 454 } 455 return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 456 } 457 458 throw std::invalid_argument( 459 std::string("unimplemented array format conversion from format: ") + 460 arrayInfo.format); 461 } 462 463 static PyDenseElementsAttribute getSplat(PyType shapedType, 464 PyAttribute &elementAttr) { 465 auto contextWrapper = 466 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 467 if (!mlirAttributeIsAInteger(elementAttr) && 468 !mlirAttributeIsAFloat(elementAttr)) { 469 std::string message = "Illegal element type for DenseElementsAttr: "; 470 message.append(py::repr(py::cast(elementAttr))); 471 throw SetPyError(PyExc_ValueError, message); 472 } 473 if (!mlirTypeIsAShaped(shapedType) || 474 !mlirShapedTypeHasStaticShape(shapedType)) { 475 std::string message = 476 "Expected a static ShapedType for the shaped_type parameter: "; 477 message.append(py::repr(py::cast(shapedType))); 478 throw SetPyError(PyExc_ValueError, message); 479 } 480 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 481 MlirType attrType = mlirAttributeGetType(elementAttr); 482 if (!mlirTypeEqual(shapedElementType, attrType)) { 483 std::string message = 484 "Shaped element type and attribute type must be equal: shaped="; 485 message.append(py::repr(py::cast(shapedType))); 486 message.append(", element="); 487 message.append(py::repr(py::cast(elementAttr))); 488 throw SetPyError(PyExc_ValueError, message); 489 } 490 491 MlirAttribute elements = 492 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 493 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 494 } 495 496 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 497 498 py::buffer_info accessBuffer() { 499 if (mlirDenseElementsAttrIsSplat(*this)) { 500 // TODO: Currently crashes the program. 501 // Reported as https://github.com/pybind/pybind11/issues/3336 502 throw std::invalid_argument( 503 "unsupported data type for conversion to Python buffer"); 504 } 505 506 MlirType shapedType = mlirAttributeGetType(*this); 507 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 508 std::string format; 509 510 if (mlirTypeIsAF32(elementType)) { 511 // f32 512 return bufferInfo<float>(shapedType); 513 } else if (mlirTypeIsAF64(elementType)) { 514 // f64 515 return bufferInfo<double>(shapedType); 516 } else if (mlirTypeIsAF16(elementType)) { 517 // f16 518 return bufferInfo<uint16_t>(shapedType, "e"); 519 } else if (mlirTypeIsAInteger(elementType) && 520 mlirIntegerTypeGetWidth(elementType) == 32) { 521 if (mlirIntegerTypeIsSignless(elementType) || 522 mlirIntegerTypeIsSigned(elementType)) { 523 // i32 524 return bufferInfo<int32_t>(shapedType); 525 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 526 // unsigned i32 527 return bufferInfo<uint32_t>(shapedType); 528 } 529 } else if (mlirTypeIsAInteger(elementType) && 530 mlirIntegerTypeGetWidth(elementType) == 64) { 531 if (mlirIntegerTypeIsSignless(elementType) || 532 mlirIntegerTypeIsSigned(elementType)) { 533 // i64 534 return bufferInfo<int64_t>(shapedType); 535 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 536 // unsigned i64 537 return bufferInfo<uint64_t>(shapedType); 538 } 539 } else if (mlirTypeIsAInteger(elementType) && 540 mlirIntegerTypeGetWidth(elementType) == 8) { 541 if (mlirIntegerTypeIsSignless(elementType) || 542 mlirIntegerTypeIsSigned(elementType)) { 543 // i8 544 return bufferInfo<int8_t>(shapedType); 545 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 546 // unsigned i8 547 return bufferInfo<uint8_t>(shapedType); 548 } 549 } else if (mlirTypeIsAInteger(elementType) && 550 mlirIntegerTypeGetWidth(elementType) == 16) { 551 if (mlirIntegerTypeIsSignless(elementType) || 552 mlirIntegerTypeIsSigned(elementType)) { 553 // i16 554 return bufferInfo<int16_t>(shapedType); 555 } else if (mlirIntegerTypeIsUnsigned(elementType)) { 556 // unsigned i16 557 return bufferInfo<uint16_t>(shapedType); 558 } 559 } 560 561 // TODO: Currently crashes the program. 562 // Reported as https://github.com/pybind/pybind11/issues/3336 563 throw std::invalid_argument( 564 "unsupported data type for conversion to Python buffer"); 565 } 566 567 static void bindDerived(ClassTy &c) { 568 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 569 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 570 py::arg("array"), py::arg("signless") = true, 571 py::arg("type") = py::none(), py::arg("shape") = py::none(), 572 py::arg("context") = py::none(), 573 kDenseElementsAttrGetDocstring) 574 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 575 py::arg("shaped_type"), py::arg("element_attr"), 576 "Gets a DenseElementsAttr where all values are the same") 577 .def_property_readonly("is_splat", 578 [](PyDenseElementsAttribute &self) -> bool { 579 return mlirDenseElementsAttrIsSplat(self); 580 }) 581 .def_buffer(&PyDenseElementsAttribute::accessBuffer); 582 } 583 584 private: 585 static bool isUnsignedIntegerFormat(const std::string &format) { 586 if (format.empty()) 587 return false; 588 char code = format[0]; 589 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 590 code == 'Q'; 591 } 592 593 static bool isSignedIntegerFormat(const std::string &format) { 594 if (format.empty()) 595 return false; 596 char code = format[0]; 597 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 598 code == 'q'; 599 } 600 601 template <typename Type> 602 py::buffer_info bufferInfo(MlirType shapedType, 603 const char *explicitFormat = nullptr) { 604 intptr_t rank = mlirShapedTypeGetRank(shapedType); 605 // Prepare the data for the buffer_info. 606 // Buffer is configured for read-only access below. 607 Type *data = static_cast<Type *>( 608 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 609 // Prepare the shape for the buffer_info. 610 SmallVector<intptr_t, 4> shape; 611 for (intptr_t i = 0; i < rank; ++i) 612 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 613 // Prepare the strides for the buffer_info. 614 SmallVector<intptr_t, 4> strides; 615 intptr_t strideFactor = 1; 616 for (intptr_t i = 1; i < rank; ++i) { 617 strideFactor = 1; 618 for (intptr_t j = i; j < rank; ++j) { 619 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 620 } 621 strides.push_back(sizeof(Type) * strideFactor); 622 } 623 strides.push_back(sizeof(Type)); 624 std::string format; 625 if (explicitFormat) { 626 format = explicitFormat; 627 } else { 628 format = py::format_descriptor<Type>::format(); 629 } 630 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 631 /*readonly=*/true); 632 } 633 }; // namespace 634 635 /// Refinement of the PyDenseElementsAttribute for attributes containing integer 636 /// (and boolean) values. Supports element access. 637 class PyDenseIntElementsAttribute 638 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 639 PyDenseElementsAttribute> { 640 public: 641 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 642 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 643 using PyConcreteAttribute::PyConcreteAttribute; 644 645 /// Returns the element at the given linear position. Asserts if the index is 646 /// out of range. 647 py::int_ dunderGetItem(intptr_t pos) { 648 if (pos < 0 || pos >= dunderLen()) { 649 throw SetPyError(PyExc_IndexError, 650 "attempt to access out of bounds element"); 651 } 652 653 MlirType type = mlirAttributeGetType(*this); 654 type = mlirShapedTypeGetElementType(type); 655 assert(mlirTypeIsAInteger(type) && 656 "expected integer element type in dense int elements attribute"); 657 // Dispatch element extraction to an appropriate C function based on the 658 // elemental type of the attribute. py::int_ is implicitly constructible 659 // from any C++ integral type and handles bitwidth correctly. 660 // TODO: consider caching the type properties in the constructor to avoid 661 // querying them on each element access. 662 unsigned width = mlirIntegerTypeGetWidth(type); 663 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 664 if (isUnsigned) { 665 if (width == 1) { 666 return mlirDenseElementsAttrGetBoolValue(*this, pos); 667 } 668 if (width == 32) { 669 return mlirDenseElementsAttrGetUInt32Value(*this, pos); 670 } 671 if (width == 64) { 672 return mlirDenseElementsAttrGetUInt64Value(*this, pos); 673 } 674 } else { 675 if (width == 1) { 676 return mlirDenseElementsAttrGetBoolValue(*this, pos); 677 } 678 if (width == 32) { 679 return mlirDenseElementsAttrGetInt32Value(*this, pos); 680 } 681 if (width == 64) { 682 return mlirDenseElementsAttrGetInt64Value(*this, pos); 683 } 684 } 685 throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 686 } 687 688 static void bindDerived(ClassTy &c) { 689 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 690 } 691 }; 692 693 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 694 public: 695 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 696 static constexpr const char *pyClassName = "DictAttr"; 697 using PyConcreteAttribute::PyConcreteAttribute; 698 699 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 700 701 bool dunderContains(const std::string &name) { 702 return !mlirAttributeIsNull( 703 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 704 } 705 706 static void bindDerived(ClassTy &c) { 707 c.def("__contains__", &PyDictAttribute::dunderContains); 708 c.def("__len__", &PyDictAttribute::dunderLen); 709 c.def_static( 710 "get", 711 [](py::dict attributes, DefaultingPyMlirContext context) { 712 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 713 mlirNamedAttributes.reserve(attributes.size()); 714 for (auto &it : attributes) { 715 auto &mlir_attr = it.second.cast<PyAttribute &>(); 716 auto name = it.first.cast<std::string>(); 717 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 718 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 719 toMlirStringRef(name)), 720 mlir_attr)); 721 } 722 MlirAttribute attr = 723 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 724 mlirNamedAttributes.data()); 725 return PyDictAttribute(context->getRef(), attr); 726 }, 727 py::arg("value") = py::dict(), py::arg("context") = py::none(), 728 "Gets an uniqued dict attribute"); 729 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 730 MlirAttribute attr = 731 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 732 if (mlirAttributeIsNull(attr)) { 733 throw SetPyError(PyExc_KeyError, 734 "attempt to access a non-existent attribute"); 735 } 736 return PyAttribute(self.getContext(), attr); 737 }); 738 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 739 if (index < 0 || index >= self.dunderLen()) { 740 throw SetPyError(PyExc_IndexError, 741 "attempt to access out of bounds attribute"); 742 } 743 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 744 return PyNamedAttribute( 745 namedAttr.attribute, 746 std::string(mlirIdentifierStr(namedAttr.name).data)); 747 }); 748 } 749 }; 750 751 /// Refinement of PyDenseElementsAttribute for attributes containing 752 /// floating-point values. Supports element access. 753 class PyDenseFPElementsAttribute 754 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 755 PyDenseElementsAttribute> { 756 public: 757 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 758 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 759 using PyConcreteAttribute::PyConcreteAttribute; 760 761 py::float_ dunderGetItem(intptr_t pos) { 762 if (pos < 0 || pos >= dunderLen()) { 763 throw SetPyError(PyExc_IndexError, 764 "attempt to access out of bounds element"); 765 } 766 767 MlirType type = mlirAttributeGetType(*this); 768 type = mlirShapedTypeGetElementType(type); 769 // Dispatch element extraction to an appropriate C function based on the 770 // elemental type of the attribute. py::float_ is implicitly constructible 771 // from float and double. 772 // TODO: consider caching the type properties in the constructor to avoid 773 // querying them on each element access. 774 if (mlirTypeIsAF32(type)) { 775 return mlirDenseElementsAttrGetFloatValue(*this, pos); 776 } 777 if (mlirTypeIsAF64(type)) { 778 return mlirDenseElementsAttrGetDoubleValue(*this, pos); 779 } 780 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 781 } 782 783 static void bindDerived(ClassTy &c) { 784 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 785 } 786 }; 787 788 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 789 public: 790 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 791 static constexpr const char *pyClassName = "TypeAttr"; 792 using PyConcreteAttribute::PyConcreteAttribute; 793 794 static void bindDerived(ClassTy &c) { 795 c.def_static( 796 "get", 797 [](PyType value, DefaultingPyMlirContext context) { 798 MlirAttribute attr = mlirTypeAttrGet(value.get()); 799 return PyTypeAttribute(context->getRef(), attr); 800 }, 801 py::arg("value"), py::arg("context") = py::none(), 802 "Gets a uniqued Type attribute"); 803 c.def_property_readonly("value", [](PyTypeAttribute &self) { 804 return PyType(self.getContext()->getRef(), 805 mlirTypeAttrGetValue(self.get())); 806 }); 807 } 808 }; 809 810 /// Unit Attribute subclass. Unit attributes don't have values. 811 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 812 public: 813 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 814 static constexpr const char *pyClassName = "UnitAttr"; 815 using PyConcreteAttribute::PyConcreteAttribute; 816 817 static void bindDerived(ClassTy &c) { 818 c.def_static( 819 "get", 820 [](DefaultingPyMlirContext context) { 821 return PyUnitAttribute(context->getRef(), 822 mlirUnitAttrGet(context->get())); 823 }, 824 py::arg("context") = py::none(), "Create a Unit attribute."); 825 } 826 }; 827 828 } // namespace 829 830 void mlir::python::populateIRAttributes(py::module &m) { 831 PyAffineMapAttribute::bind(m); 832 PyArrayAttribute::bind(m); 833 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 834 PyBoolAttribute::bind(m); 835 PyDenseElementsAttribute::bind(m); 836 PyDenseFPElementsAttribute::bind(m); 837 PyDenseIntElementsAttribute::bind(m); 838 PyDictAttribute::bind(m); 839 PyFlatSymbolRefAttribute::bind(m); 840 PyFloatAttribute::bind(m); 841 PyIntegerAttribute::bind(m); 842 PyStringAttribute::bind(m); 843 PyTypeAttribute::bind(m); 844 PyUnitAttribute::bind(m); 845 } 846