1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "IRModule.h"
12
13 #include "PybindUtils.h"
14
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17
18 namespace py = pybind11;
19 using namespace mlir;
20 using namespace mlir::python;
21
22 using llvm::Optional;
23 using llvm::SmallVector;
24 using llvm::Twine;
25
26 //------------------------------------------------------------------------------
27 // Docstrings (trivial, non-duplicated docstrings are included inline).
28 //------------------------------------------------------------------------------
29
// Python docstring attached to the static `DenseElementsAttr.get` binding,
// which is implemented by PyDenseElementsAttribute::getFromBuffer.
// NOTE(review): keep this text in sync with getFromBuffer's format-based type
// inference and its error behavior when either one changes.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

* Integer types (except for i1): the buffer must be byte aligned to the
next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ (as
for i1).
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
71
72 namespace {
73
toMlirStringRef(const std::string & s)74 static MlirStringRef toMlirStringRef(const std::string &s) {
75 return mlirStringRefCreate(s.data(), s.size());
76 }
77
78 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79 public:
80 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81 static constexpr const char *pyClassName = "AffineMapAttr";
82 using PyConcreteAttribute::PyConcreteAttribute;
83
bindDerived(ClassTy & c)84 static void bindDerived(ClassTy &c) {
85 c.def_static(
86 "get",
87 [](PyAffineMap &affineMap) {
88 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
89 return PyAffineMapAttribute(affineMap.getContext(), attr);
90 },
91 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
92 }
93 };
94
95 template <typename T>
pyTryCast(py::handle object)96 static T pyTryCast(py::handle object) {
97 try {
98 return object.cast<T>();
99 } catch (py::cast_error &err) {
100 std::string msg =
101 std::string(
102 "Invalid attribute when attempting to create an ArrayAttribute (") +
103 err.what() + ")";
104 throw py::cast_error(msg);
105 } catch (py::reference_cast_error &err) {
106 std::string msg = std::string("Invalid attribute (None?) when attempting "
107 "to create an ArrayAttribute (") +
108 err.what() + ")";
109 throw py::cast_error(msg);
110 }
111 }
112
113 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
114 public:
115 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
116 static constexpr const char *pyClassName = "ArrayAttr";
117 using PyConcreteAttribute::PyConcreteAttribute;
118
119 class PyArrayAttributeIterator {
120 public:
PyArrayAttributeIterator(PyAttribute attr)121 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
122
dunderIter()123 PyArrayAttributeIterator &dunderIter() { return *this; }
124
dunderNext()125 PyAttribute dunderNext() {
126 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
127 throw py::stop_iteration();
128 }
129 return PyAttribute(attr.getContext(),
130 mlirArrayAttrGetElement(attr.get(), nextIndex++));
131 }
132
bind(py::module & m)133 static void bind(py::module &m) {
134 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
135 py::module_local())
136 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
137 .def("__next__", &PyArrayAttributeIterator::dunderNext);
138 }
139
140 private:
141 PyAttribute attr;
142 int nextIndex = 0;
143 };
144
getItem(intptr_t i)145 PyAttribute getItem(intptr_t i) {
146 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
147 }
148
bindDerived(ClassTy & c)149 static void bindDerived(ClassTy &c) {
150 c.def_static(
151 "get",
152 [](py::list attributes, DefaultingPyMlirContext context) {
153 SmallVector<MlirAttribute> mlirAttributes;
154 mlirAttributes.reserve(py::len(attributes));
155 for (auto attribute : attributes) {
156 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
157 }
158 MlirAttribute attr = mlirArrayAttrGet(
159 context->get(), mlirAttributes.size(), mlirAttributes.data());
160 return PyArrayAttribute(context->getRef(), attr);
161 },
162 py::arg("attributes"), py::arg("context") = py::none(),
163 "Gets a uniqued Array attribute");
164 c.def("__getitem__",
165 [](PyArrayAttribute &arr, intptr_t i) {
166 if (i >= mlirArrayAttrGetNumElements(arr))
167 throw py::index_error("ArrayAttribute index out of range");
168 return arr.getItem(i);
169 })
170 .def("__len__",
171 [](const PyArrayAttribute &arr) {
172 return mlirArrayAttrGetNumElements(arr);
173 })
174 .def("__iter__", [](const PyArrayAttribute &arr) {
175 return PyArrayAttributeIterator(arr);
176 });
177 c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
178 std::vector<MlirAttribute> attributes;
179 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
180 attributes.reserve(numOldElements + py::len(extras));
181 for (intptr_t i = 0; i < numOldElements; ++i)
182 attributes.push_back(arr.getItem(i));
183 for (py::handle attr : extras)
184 attributes.push_back(pyTryCast<PyAttribute>(attr));
185 MlirAttribute arrayAttr = mlirArrayAttrGet(
186 arr.getContext()->get(), attributes.size(), attributes.data());
187 return PyArrayAttribute(arr.getContext(), arrayAttr);
188 });
189 }
190 };
191
192 /// Float Point Attribute subclass - FloatAttr.
193 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
194 public:
195 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
196 static constexpr const char *pyClassName = "FloatAttr";
197 using PyConcreteAttribute::PyConcreteAttribute;
198
bindDerived(ClassTy & c)199 static void bindDerived(ClassTy &c) {
200 c.def_static(
201 "get",
202 [](PyType &type, double value, DefaultingPyLocation loc) {
203 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
204 // TODO: Rework error reporting once diagnostic engine is exposed
205 // in C API.
206 if (mlirAttributeIsNull(attr)) {
207 throw SetPyError(PyExc_ValueError,
208 Twine("invalid '") +
209 py::repr(py::cast(type)).cast<std::string>() +
210 "' and expected floating point type.");
211 }
212 return PyFloatAttribute(type.getContext(), attr);
213 },
214 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
215 "Gets an uniqued float point attribute associated to a type");
216 c.def_static(
217 "get_f32",
218 [](double value, DefaultingPyMlirContext context) {
219 MlirAttribute attr = mlirFloatAttrDoubleGet(
220 context->get(), mlirF32TypeGet(context->get()), value);
221 return PyFloatAttribute(context->getRef(), attr);
222 },
223 py::arg("value"), py::arg("context") = py::none(),
224 "Gets an uniqued float point attribute associated to a f32 type");
225 c.def_static(
226 "get_f64",
227 [](double value, DefaultingPyMlirContext context) {
228 MlirAttribute attr = mlirFloatAttrDoubleGet(
229 context->get(), mlirF64TypeGet(context->get()), value);
230 return PyFloatAttribute(context->getRef(), attr);
231 },
232 py::arg("value"), py::arg("context") = py::none(),
233 "Gets an uniqued float point attribute associated to a f64 type");
234 c.def_property_readonly(
235 "value",
236 [](PyFloatAttribute &self) {
237 return mlirFloatAttrGetValueDouble(self);
238 },
239 "Returns the value of the float point attribute");
240 }
241 };
242
243 /// Integer Attribute subclass - IntegerAttr.
244 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
245 public:
246 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
247 static constexpr const char *pyClassName = "IntegerAttr";
248 using PyConcreteAttribute::PyConcreteAttribute;
249
bindDerived(ClassTy & c)250 static void bindDerived(ClassTy &c) {
251 c.def_static(
252 "get",
253 [](PyType &type, int64_t value) {
254 MlirAttribute attr = mlirIntegerAttrGet(type, value);
255 return PyIntegerAttribute(type.getContext(), attr);
256 },
257 py::arg("type"), py::arg("value"),
258 "Gets an uniqued integer attribute associated to a type");
259 c.def_property_readonly(
260 "value",
261 [](PyIntegerAttribute &self) -> py::int_ {
262 MlirType type = mlirAttributeGetType(self);
263 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
264 return mlirIntegerAttrGetValueInt(self);
265 if (mlirIntegerTypeIsSigned(type))
266 return mlirIntegerAttrGetValueSInt(self);
267 return mlirIntegerAttrGetValueUInt(self);
268 },
269 "Returns the value of the integer attribute");
270 }
271 };
272
273 /// Bool Attribute subclass - BoolAttr.
274 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
275 public:
276 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
277 static constexpr const char *pyClassName = "BoolAttr";
278 using PyConcreteAttribute::PyConcreteAttribute;
279
bindDerived(ClassTy & c)280 static void bindDerived(ClassTy &c) {
281 c.def_static(
282 "get",
283 [](bool value, DefaultingPyMlirContext context) {
284 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
285 return PyBoolAttribute(context->getRef(), attr);
286 },
287 py::arg("value"), py::arg("context") = py::none(),
288 "Gets an uniqued bool attribute");
289 c.def_property_readonly(
290 "value",
291 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
292 "Returns the value of the bool attribute");
293 }
294 };
295
296 class PyFlatSymbolRefAttribute
297 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
298 public:
299 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
300 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
301 using PyConcreteAttribute::PyConcreteAttribute;
302
bindDerived(ClassTy & c)303 static void bindDerived(ClassTy &c) {
304 c.def_static(
305 "get",
306 [](std::string value, DefaultingPyMlirContext context) {
307 MlirAttribute attr =
308 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
309 return PyFlatSymbolRefAttribute(context->getRef(), attr);
310 },
311 py::arg("value"), py::arg("context") = py::none(),
312 "Gets a uniqued FlatSymbolRef attribute");
313 c.def_property_readonly(
314 "value",
315 [](PyFlatSymbolRefAttribute &self) {
316 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
317 return py::str(stringRef.data, stringRef.length);
318 },
319 "Returns the value of the FlatSymbolRef attribute as a string");
320 }
321 };
322
323 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
324 public:
325 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
326 static constexpr const char *pyClassName = "OpaqueAttr";
327 using PyConcreteAttribute::PyConcreteAttribute;
328
bindDerived(ClassTy & c)329 static void bindDerived(ClassTy &c) {
330 c.def_static(
331 "get",
332 [](std::string dialectNamespace, py::buffer buffer, PyType &type,
333 DefaultingPyMlirContext context) {
334 const py::buffer_info bufferInfo = buffer.request();
335 intptr_t bufferSize = bufferInfo.size;
336 MlirAttribute attr = mlirOpaqueAttrGet(
337 context->get(), toMlirStringRef(dialectNamespace), bufferSize,
338 static_cast<char *>(bufferInfo.ptr), type);
339 return PyOpaqueAttribute(context->getRef(), attr);
340 },
341 py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
342 py::arg("context") = py::none(), "Gets an Opaque attribute.");
343 c.def_property_readonly(
344 "dialect_namespace",
345 [](PyOpaqueAttribute &self) {
346 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
347 return py::str(stringRef.data, stringRef.length);
348 },
349 "Returns the dialect namespace for the Opaque attribute as a string");
350 c.def_property_readonly(
351 "data",
352 [](PyOpaqueAttribute &self) {
353 MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
354 return py::str(stringRef.data, stringRef.length);
355 },
356 "Returns the data for the Opaqued attributes as a string");
357 }
358 };
359
360 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
361 public:
362 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
363 static constexpr const char *pyClassName = "StringAttr";
364 using PyConcreteAttribute::PyConcreteAttribute;
365
bindDerived(ClassTy & c)366 static void bindDerived(ClassTy &c) {
367 c.def_static(
368 "get",
369 [](std::string value, DefaultingPyMlirContext context) {
370 MlirAttribute attr =
371 mlirStringAttrGet(context->get(), toMlirStringRef(value));
372 return PyStringAttribute(context->getRef(), attr);
373 },
374 py::arg("value"), py::arg("context") = py::none(),
375 "Gets a uniqued string attribute");
376 c.def_static(
377 "get_typed",
378 [](PyType &type, std::string value) {
379 MlirAttribute attr =
380 mlirStringAttrTypedGet(type, toMlirStringRef(value));
381 return PyStringAttribute(type.getContext(), attr);
382 },
383 py::arg("type"), py::arg("value"),
384 "Gets a uniqued string attribute associated to a type");
385 c.def_property_readonly(
386 "value",
387 [](PyStringAttribute &self) {
388 MlirStringRef stringRef = mlirStringAttrGetValue(self);
389 return py::str(stringRef.data, stringRef.length);
390 },
391 "Returns the value of the string attribute");
392 }
393 };
394
395 // TODO: Support construction of string elements.
396 class PyDenseElementsAttribute
397 : public PyConcreteAttribute<PyDenseElementsAttribute> {
398 public:
399 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
400 static constexpr const char *pyClassName = "DenseElementsAttr";
401 using PyConcreteAttribute::PyConcreteAttribute;
402
403 static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,Optional<PyType> explicitType,Optional<std::vector<int64_t>> explicitShape,DefaultingPyMlirContext contextWrapper)404 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
405 Optional<std::vector<int64_t>> explicitShape,
406 DefaultingPyMlirContext contextWrapper) {
407 // Request a contiguous view. In exotic cases, this will cause a copy.
408 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
409 Py_buffer *view = new Py_buffer();
410 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
411 delete view;
412 throw py::error_already_set();
413 }
414 py::buffer_info arrayInfo(view);
415 SmallVector<int64_t> shape;
416 if (explicitShape) {
417 shape.append(explicitShape->begin(), explicitShape->end());
418 } else {
419 shape.append(arrayInfo.shape.begin(),
420 arrayInfo.shape.begin() + arrayInfo.ndim);
421 }
422
423 MlirAttribute encodingAttr = mlirAttributeGetNull();
424 MlirContext context = contextWrapper->get();
425
426 // Detect format codes that are suitable for bulk loading. This includes
427 // all byte aligned integer and floating point types up to 8 bytes.
428 // Notably, this excludes, bool (which needs to be bit-packed) and
429 // other exotics which do not have a direct representation in the buffer
430 // protocol (i.e. complex, etc).
431 Optional<MlirType> bulkLoadElementType;
432 if (explicitType) {
433 bulkLoadElementType = *explicitType;
434 } else if (arrayInfo.format == "f") {
435 // f32
436 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
437 bulkLoadElementType = mlirF32TypeGet(context);
438 } else if (arrayInfo.format == "d") {
439 // f64
440 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
441 bulkLoadElementType = mlirF64TypeGet(context);
442 } else if (arrayInfo.format == "e") {
443 // f16
444 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
445 bulkLoadElementType = mlirF16TypeGet(context);
446 } else if (isSignedIntegerFormat(arrayInfo.format)) {
447 if (arrayInfo.itemsize == 4) {
448 // i32
449 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
450 : mlirIntegerTypeSignedGet(context, 32);
451 } else if (arrayInfo.itemsize == 8) {
452 // i64
453 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
454 : mlirIntegerTypeSignedGet(context, 64);
455 } else if (arrayInfo.itemsize == 1) {
456 // i8
457 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
458 : mlirIntegerTypeSignedGet(context, 8);
459 } else if (arrayInfo.itemsize == 2) {
460 // i16
461 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
462 : mlirIntegerTypeSignedGet(context, 16);
463 }
464 } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
465 if (arrayInfo.itemsize == 4) {
466 // unsigned i32
467 bulkLoadElementType = signless
468 ? mlirIntegerTypeGet(context, 32)
469 : mlirIntegerTypeUnsignedGet(context, 32);
470 } else if (arrayInfo.itemsize == 8) {
471 // unsigned i64
472 bulkLoadElementType = signless
473 ? mlirIntegerTypeGet(context, 64)
474 : mlirIntegerTypeUnsignedGet(context, 64);
475 } else if (arrayInfo.itemsize == 1) {
476 // i8
477 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
478 : mlirIntegerTypeUnsignedGet(context, 8);
479 } else if (arrayInfo.itemsize == 2) {
480 // i16
481 bulkLoadElementType = signless
482 ? mlirIntegerTypeGet(context, 16)
483 : mlirIntegerTypeUnsignedGet(context, 16);
484 }
485 }
486 if (bulkLoadElementType) {
487 auto shapedType = mlirRankedTensorTypeGet(
488 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
489 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
490 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
491 shapedType, rawBufferSize, arrayInfo.ptr);
492 if (mlirAttributeIsNull(attr)) {
493 throw std::invalid_argument(
494 "DenseElementsAttr could not be constructed from the given buffer. "
495 "This may mean that the Python buffer layout does not match that "
496 "MLIR expected layout and is a bug.");
497 }
498 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
499 }
500
501 throw std::invalid_argument(
502 std::string("unimplemented array format conversion from format: ") +
503 arrayInfo.format);
504 }
505
getSplat(const PyType & shapedType,PyAttribute & elementAttr)506 static PyDenseElementsAttribute getSplat(const PyType &shapedType,
507 PyAttribute &elementAttr) {
508 auto contextWrapper =
509 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
510 if (!mlirAttributeIsAInteger(elementAttr) &&
511 !mlirAttributeIsAFloat(elementAttr)) {
512 std::string message = "Illegal element type for DenseElementsAttr: ";
513 message.append(py::repr(py::cast(elementAttr)));
514 throw SetPyError(PyExc_ValueError, message);
515 }
516 if (!mlirTypeIsAShaped(shapedType) ||
517 !mlirShapedTypeHasStaticShape(shapedType)) {
518 std::string message =
519 "Expected a static ShapedType for the shaped_type parameter: ";
520 message.append(py::repr(py::cast(shapedType)));
521 throw SetPyError(PyExc_ValueError, message);
522 }
523 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
524 MlirType attrType = mlirAttributeGetType(elementAttr);
525 if (!mlirTypeEqual(shapedElementType, attrType)) {
526 std::string message =
527 "Shaped element type and attribute type must be equal: shaped=";
528 message.append(py::repr(py::cast(shapedType)));
529 message.append(", element=");
530 message.append(py::repr(py::cast(elementAttr)));
531 throw SetPyError(PyExc_ValueError, message);
532 }
533
534 MlirAttribute elements =
535 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
536 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
537 }
538
dunderLen()539 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
540
accessBuffer()541 py::buffer_info accessBuffer() {
542 if (mlirDenseElementsAttrIsSplat(*this)) {
543 // TODO: Currently crashes the program.
544 // Reported as https://github.com/pybind/pybind11/issues/3336
545 throw std::invalid_argument(
546 "unsupported data type for conversion to Python buffer");
547 }
548
549 MlirType shapedType = mlirAttributeGetType(*this);
550 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
551 std::string format;
552
553 if (mlirTypeIsAF32(elementType)) {
554 // f32
555 return bufferInfo<float>(shapedType);
556 }
557 if (mlirTypeIsAF64(elementType)) {
558 // f64
559 return bufferInfo<double>(shapedType);
560 }
561 if (mlirTypeIsAF16(elementType)) {
562 // f16
563 return bufferInfo<uint16_t>(shapedType, "e");
564 }
565 if (mlirTypeIsAInteger(elementType) &&
566 mlirIntegerTypeGetWidth(elementType) == 32) {
567 if (mlirIntegerTypeIsSignless(elementType) ||
568 mlirIntegerTypeIsSigned(elementType)) {
569 // i32
570 return bufferInfo<int32_t>(shapedType);
571 }
572 if (mlirIntegerTypeIsUnsigned(elementType)) {
573 // unsigned i32
574 return bufferInfo<uint32_t>(shapedType);
575 }
576 } else if (mlirTypeIsAInteger(elementType) &&
577 mlirIntegerTypeGetWidth(elementType) == 64) {
578 if (mlirIntegerTypeIsSignless(elementType) ||
579 mlirIntegerTypeIsSigned(elementType)) {
580 // i64
581 return bufferInfo<int64_t>(shapedType);
582 }
583 if (mlirIntegerTypeIsUnsigned(elementType)) {
584 // unsigned i64
585 return bufferInfo<uint64_t>(shapedType);
586 }
587 } else if (mlirTypeIsAInteger(elementType) &&
588 mlirIntegerTypeGetWidth(elementType) == 8) {
589 if (mlirIntegerTypeIsSignless(elementType) ||
590 mlirIntegerTypeIsSigned(elementType)) {
591 // i8
592 return bufferInfo<int8_t>(shapedType);
593 }
594 if (mlirIntegerTypeIsUnsigned(elementType)) {
595 // unsigned i8
596 return bufferInfo<uint8_t>(shapedType);
597 }
598 } else if (mlirTypeIsAInteger(elementType) &&
599 mlirIntegerTypeGetWidth(elementType) == 16) {
600 if (mlirIntegerTypeIsSignless(elementType) ||
601 mlirIntegerTypeIsSigned(elementType)) {
602 // i16
603 return bufferInfo<int16_t>(shapedType);
604 }
605 if (mlirIntegerTypeIsUnsigned(elementType)) {
606 // unsigned i16
607 return bufferInfo<uint16_t>(shapedType);
608 }
609 }
610
611 // TODO: Currently crashes the program.
612 // Reported as https://github.com/pybind/pybind11/issues/3336
613 throw std::invalid_argument(
614 "unsupported data type for conversion to Python buffer");
615 }
616
bindDerived(ClassTy & c)617 static void bindDerived(ClassTy &c) {
618 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
619 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
620 py::arg("array"), py::arg("signless") = true,
621 py::arg("type") = py::none(), py::arg("shape") = py::none(),
622 py::arg("context") = py::none(),
623 kDenseElementsAttrGetDocstring)
624 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
625 py::arg("shaped_type"), py::arg("element_attr"),
626 "Gets a DenseElementsAttr where all values are the same")
627 .def_property_readonly("is_splat",
628 [](PyDenseElementsAttribute &self) -> bool {
629 return mlirDenseElementsAttrIsSplat(self);
630 })
631 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
632 }
633
634 private:
isUnsignedIntegerFormat(const std::string & format)635 static bool isUnsignedIntegerFormat(const std::string &format) {
636 if (format.empty())
637 return false;
638 char code = format[0];
639 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
640 code == 'Q';
641 }
642
isSignedIntegerFormat(const std::string & format)643 static bool isSignedIntegerFormat(const std::string &format) {
644 if (format.empty())
645 return false;
646 char code = format[0];
647 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
648 code == 'q';
649 }
650
651 template <typename Type>
bufferInfo(MlirType shapedType,const char * explicitFormat=nullptr)652 py::buffer_info bufferInfo(MlirType shapedType,
653 const char *explicitFormat = nullptr) {
654 intptr_t rank = mlirShapedTypeGetRank(shapedType);
655 // Prepare the data for the buffer_info.
656 // Buffer is configured for read-only access below.
657 Type *data = static_cast<Type *>(
658 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
659 // Prepare the shape for the buffer_info.
660 SmallVector<intptr_t, 4> shape;
661 for (intptr_t i = 0; i < rank; ++i)
662 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
663 // Prepare the strides for the buffer_info.
664 SmallVector<intptr_t, 4> strides;
665 intptr_t strideFactor = 1;
666 for (intptr_t i = 1; i < rank; ++i) {
667 strideFactor = 1;
668 for (intptr_t j = i; j < rank; ++j) {
669 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
670 }
671 strides.push_back(sizeof(Type) * strideFactor);
672 }
673 strides.push_back(sizeof(Type));
674 std::string format;
675 if (explicitFormat) {
676 format = explicitFormat;
677 } else {
678 format = py::format_descriptor<Type>::format();
679 }
680 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
681 /*readonly=*/true);
682 }
683 }; // namespace
684
685 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
686 /// (and boolean) values. Supports element access.
687 class PyDenseIntElementsAttribute
688 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
689 PyDenseElementsAttribute> {
690 public:
691 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
692 static constexpr const char *pyClassName = "DenseIntElementsAttr";
693 using PyConcreteAttribute::PyConcreteAttribute;
694
695 /// Returns the element at the given linear position. Asserts if the index is
696 /// out of range.
dunderGetItem(intptr_t pos)697 py::int_ dunderGetItem(intptr_t pos) {
698 if (pos < 0 || pos >= dunderLen()) {
699 throw SetPyError(PyExc_IndexError,
700 "attempt to access out of bounds element");
701 }
702
703 MlirType type = mlirAttributeGetType(*this);
704 type = mlirShapedTypeGetElementType(type);
705 assert(mlirTypeIsAInteger(type) &&
706 "expected integer element type in dense int elements attribute");
707 // Dispatch element extraction to an appropriate C function based on the
708 // elemental type of the attribute. py::int_ is implicitly constructible
709 // from any C++ integral type and handles bitwidth correctly.
710 // TODO: consider caching the type properties in the constructor to avoid
711 // querying them on each element access.
712 unsigned width = mlirIntegerTypeGetWidth(type);
713 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
714 if (isUnsigned) {
715 if (width == 1) {
716 return mlirDenseElementsAttrGetBoolValue(*this, pos);
717 }
718 if (width == 8) {
719 return mlirDenseElementsAttrGetUInt8Value(*this, pos);
720 }
721 if (width == 16) {
722 return mlirDenseElementsAttrGetUInt16Value(*this, pos);
723 }
724 if (width == 32) {
725 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
726 }
727 if (width == 64) {
728 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
729 }
730 } else {
731 if (width == 1) {
732 return mlirDenseElementsAttrGetBoolValue(*this, pos);
733 }
734 if (width == 8) {
735 return mlirDenseElementsAttrGetInt8Value(*this, pos);
736 }
737 if (width == 16) {
738 return mlirDenseElementsAttrGetInt16Value(*this, pos);
739 }
740 if (width == 32) {
741 return mlirDenseElementsAttrGetInt32Value(*this, pos);
742 }
743 if (width == 64) {
744 return mlirDenseElementsAttrGetInt64Value(*this, pos);
745 }
746 }
747 throw SetPyError(PyExc_TypeError, "Unsupported integer type");
748 }
749
bindDerived(ClassTy & c)750 static void bindDerived(ClassTy &c) {
751 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
752 }
753 };
754
755 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
756 public:
757 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
758 static constexpr const char *pyClassName = "DictAttr";
759 using PyConcreteAttribute::PyConcreteAttribute;
760
dunderLen()761 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
762
dunderContains(const std::string & name)763 bool dunderContains(const std::string &name) {
764 return !mlirAttributeIsNull(
765 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
766 }
767
bindDerived(ClassTy & c)768 static void bindDerived(ClassTy &c) {
769 c.def("__contains__", &PyDictAttribute::dunderContains);
770 c.def("__len__", &PyDictAttribute::dunderLen);
771 c.def_static(
772 "get",
773 [](py::dict attributes, DefaultingPyMlirContext context) {
774 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
775 mlirNamedAttributes.reserve(attributes.size());
776 for (auto &it : attributes) {
777 auto &mlirAttr = it.second.cast<PyAttribute &>();
778 auto name = it.first.cast<std::string>();
779 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
780 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
781 toMlirStringRef(name)),
782 mlirAttr));
783 }
784 MlirAttribute attr =
785 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
786 mlirNamedAttributes.data());
787 return PyDictAttribute(context->getRef(), attr);
788 },
789 py::arg("value") = py::dict(), py::arg("context") = py::none(),
790 "Gets an uniqued dict attribute");
791 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
792 MlirAttribute attr =
793 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
794 if (mlirAttributeIsNull(attr)) {
795 throw SetPyError(PyExc_KeyError,
796 "attempt to access a non-existent attribute");
797 }
798 return PyAttribute(self.getContext(), attr);
799 });
800 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
801 if (index < 0 || index >= self.dunderLen()) {
802 throw SetPyError(PyExc_IndexError,
803 "attempt to access out of bounds attribute");
804 }
805 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
806 return PyNamedAttribute(
807 namedAttr.attribute,
808 std::string(mlirIdentifierStr(namedAttr.name).data));
809 });
810 }
811 };
812
813 /// Refinement of PyDenseElementsAttribute for attributes containing
814 /// floating-point values. Supports element access.
815 class PyDenseFPElementsAttribute
816 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
817 PyDenseElementsAttribute> {
818 public:
819 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
820 static constexpr const char *pyClassName = "DenseFPElementsAttr";
821 using PyConcreteAttribute::PyConcreteAttribute;
822
dunderGetItem(intptr_t pos)823 py::float_ dunderGetItem(intptr_t pos) {
824 if (pos < 0 || pos >= dunderLen()) {
825 throw SetPyError(PyExc_IndexError,
826 "attempt to access out of bounds element");
827 }
828
829 MlirType type = mlirAttributeGetType(*this);
830 type = mlirShapedTypeGetElementType(type);
831 // Dispatch element extraction to an appropriate C function based on the
832 // elemental type of the attribute. py::float_ is implicitly constructible
833 // from float and double.
834 // TODO: consider caching the type properties in the constructor to avoid
835 // querying them on each element access.
836 if (mlirTypeIsAF32(type)) {
837 return mlirDenseElementsAttrGetFloatValue(*this, pos);
838 }
839 if (mlirTypeIsAF64(type)) {
840 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
841 }
842 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
843 }
844
bindDerived(ClassTy & c)845 static void bindDerived(ClassTy &c) {
846 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
847 }
848 };
849
850 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
851 public:
852 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
853 static constexpr const char *pyClassName = "TypeAttr";
854 using PyConcreteAttribute::PyConcreteAttribute;
855
bindDerived(ClassTy & c)856 static void bindDerived(ClassTy &c) {
857 c.def_static(
858 "get",
859 [](PyType value, DefaultingPyMlirContext context) {
860 MlirAttribute attr = mlirTypeAttrGet(value.get());
861 return PyTypeAttribute(context->getRef(), attr);
862 },
863 py::arg("value"), py::arg("context") = py::none(),
864 "Gets a uniqued Type attribute");
865 c.def_property_readonly("value", [](PyTypeAttribute &self) {
866 return PyType(self.getContext()->getRef(),
867 mlirTypeAttrGetValue(self.get()));
868 });
869 }
870 };
871
872 /// Unit Attribute subclass. Unit attributes don't have values.
873 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
874 public:
875 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
876 static constexpr const char *pyClassName = "UnitAttr";
877 using PyConcreteAttribute::PyConcreteAttribute;
878
bindDerived(ClassTy & c)879 static void bindDerived(ClassTy &c) {
880 c.def_static(
881 "get",
882 [](DefaultingPyMlirContext context) {
883 return PyUnitAttribute(context->getRef(),
884 mlirUnitAttrGet(context->get()));
885 },
886 py::arg("context") = py::none(), "Create a Unit attribute.");
887 }
888 };
889
890 } // namespace
891
populateIRAttributes(py::module & m)892 void mlir::python::populateIRAttributes(py::module &m) {
893 PyAffineMapAttribute::bind(m);
894 PyArrayAttribute::bind(m);
895 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
896 PyBoolAttribute::bind(m);
897 PyDenseElementsAttribute::bind(m);
898 PyDenseFPElementsAttribute::bind(m);
899 PyDenseIntElementsAttribute::bind(m);
900 PyDictAttribute::bind(m);
901 PyFlatSymbolRefAttribute::bind(m);
902 PyOpaqueAttribute::bind(m);
903 PyFloatAttribute::bind(m);
904 PyIntegerAttribute::bind(m);
905 PyStringAttribute::bind(m);
906 PyTypeAttribute::bind(m);
907 PyUnitAttribute::bind(m);
908 }
909