1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::StringRef;
22 using llvm::Twine;
23 
24 namespace {
25 
26 static MlirStringRef toMlirStringRef(const std::string &s) {
27   return mlirStringRefCreate(s.data(), s.size());
28 }
29 
30 /// CRTP base classes for Python attributes that subclass Attribute and should
31 /// be castable from it (i.e. via something like StringAttr(attr)).
32 /// By default, attribute class hierarchies are one level deep (i.e. a
33 /// concrete attribute class extends PyAttribute); however, intermediate
34 /// python-visible base classes can be modeled by specifying a BaseTy.
35 template <typename DerivedTy, typename BaseTy = PyAttribute>
36 class PyConcreteAttribute : public BaseTy {
37 public:
38   // Derived classes must define statics for:
39   //   IsAFunctionTy isaFunction
40   //   const char *pyClassName
41   using ClassTy = py::class_<DerivedTy, BaseTy>;
42   using IsAFunctionTy = bool (*)(MlirAttribute);
43 
44   PyConcreteAttribute() = default;
45   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
46       : BaseTy(std::move(contextRef), attr) {}
47   PyConcreteAttribute(PyAttribute &orig)
48       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
49 
50   static MlirAttribute castFrom(PyAttribute &orig) {
51     if (!DerivedTy::isaFunction(orig)) {
52       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
53       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
54                                              DerivedTy::pyClassName +
55                                              " (from " + origRepr + ")");
56     }
57     return orig;
58   }
59 
60   static void bind(py::module &m) {
61     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
62     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
63     DerivedTy::bindDerived(cls);
64   }
65 
66   /// Implemented by derived classes to add methods to the Python subclass.
67   static void bindDerived(ClassTy &m) {}
68 };
69 
70 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
71 public:
72   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
73   static constexpr const char *pyClassName = "AffineMapAttr";
74   using PyConcreteAttribute::PyConcreteAttribute;
75 
76   static void bindDerived(ClassTy &c) {
77     c.def_static(
78         "get",
79         [](PyAffineMap &affineMap) {
80           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
81           return PyAffineMapAttribute(affineMap.getContext(), attr);
82         },
83         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
84   }
85 };
86 
87 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
88 public:
89   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
90   static constexpr const char *pyClassName = "ArrayAttr";
91   using PyConcreteAttribute::PyConcreteAttribute;
92 
93   class PyArrayAttributeIterator {
94   public:
95     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
96 
97     PyArrayAttributeIterator &dunderIter() { return *this; }
98 
99     PyAttribute dunderNext() {
100       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
101         throw py::stop_iteration();
102       }
103       return PyAttribute(attr.getContext(),
104                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
105     }
106 
107     static void bind(py::module &m) {
108       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
109           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
110           .def("__next__", &PyArrayAttributeIterator::dunderNext);
111     }
112 
113   private:
114     PyAttribute attr;
115     int nextIndex = 0;
116   };
117 
118   static void bindDerived(ClassTy &c) {
119     c.def_static(
120         "get",
121         [](py::list attributes, DefaultingPyMlirContext context) {
122           SmallVector<MlirAttribute> mlirAttributes;
123           mlirAttributes.reserve(py::len(attributes));
124           for (auto attribute : attributes) {
125             try {
126               mlirAttributes.push_back(attribute.cast<PyAttribute>());
127             } catch (py::cast_error &err) {
128               std::string msg = std::string("Invalid attribute when attempting "
129                                             "to create an ArrayAttribute (") +
130                                 err.what() + ")";
131               throw py::cast_error(msg);
132             } catch (py::reference_cast_error &err) {
133               // This exception seems thrown when the value is "None".
134               std::string msg =
135                   std::string("Invalid attribute (None?) when attempting to "
136                               "create an ArrayAttribute (") +
137                   err.what() + ")";
138               throw py::cast_error(msg);
139             }
140           }
141           MlirAttribute attr = mlirArrayAttrGet(
142               context->get(), mlirAttributes.size(), mlirAttributes.data());
143           return PyArrayAttribute(context->getRef(), attr);
144         },
145         py::arg("attributes"), py::arg("context") = py::none(),
146         "Gets a uniqued Array attribute");
147     c.def("__getitem__",
148           [](PyArrayAttribute &arr, intptr_t i) {
149             if (i >= mlirArrayAttrGetNumElements(arr))
150               throw py::index_error("ArrayAttribute index out of range");
151             return PyAttribute(arr.getContext(),
152                                mlirArrayAttrGetElement(arr, i));
153           })
154         .def("__len__",
155              [](const PyArrayAttribute &arr) {
156                return mlirArrayAttrGetNumElements(arr);
157              })
158         .def("__iter__", [](const PyArrayAttribute &arr) {
159           return PyArrayAttributeIterator(arr);
160         });
161   }
162 };
163 
164 /// Float Point Attribute subclass - FloatAttr.
165 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
166 public:
167   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
168   static constexpr const char *pyClassName = "FloatAttr";
169   using PyConcreteAttribute::PyConcreteAttribute;
170 
171   static void bindDerived(ClassTy &c) {
172     c.def_static(
173         "get",
174         [](PyType &type, double value, DefaultingPyLocation loc) {
175           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
176           // TODO: Rework error reporting once diagnostic engine is exposed
177           // in C API.
178           if (mlirAttributeIsNull(attr)) {
179             throw SetPyError(PyExc_ValueError,
180                              Twine("invalid '") +
181                                  py::repr(py::cast(type)).cast<std::string>() +
182                                  "' and expected floating point type.");
183           }
184           return PyFloatAttribute(type.getContext(), attr);
185         },
186         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
187         "Gets an uniqued float point attribute associated to a type");
188     c.def_static(
189         "get_f32",
190         [](double value, DefaultingPyMlirContext context) {
191           MlirAttribute attr = mlirFloatAttrDoubleGet(
192               context->get(), mlirF32TypeGet(context->get()), value);
193           return PyFloatAttribute(context->getRef(), attr);
194         },
195         py::arg("value"), py::arg("context") = py::none(),
196         "Gets an uniqued float point attribute associated to a f32 type");
197     c.def_static(
198         "get_f64",
199         [](double value, DefaultingPyMlirContext context) {
200           MlirAttribute attr = mlirFloatAttrDoubleGet(
201               context->get(), mlirF64TypeGet(context->get()), value);
202           return PyFloatAttribute(context->getRef(), attr);
203         },
204         py::arg("value"), py::arg("context") = py::none(),
205         "Gets an uniqued float point attribute associated to a f64 type");
206     c.def_property_readonly(
207         "value",
208         [](PyFloatAttribute &self) {
209           return mlirFloatAttrGetValueDouble(self);
210         },
211         "Returns the value of the float point attribute");
212   }
213 };
214 
215 /// Integer Attribute subclass - IntegerAttr.
216 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
217 public:
218   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
219   static constexpr const char *pyClassName = "IntegerAttr";
220   using PyConcreteAttribute::PyConcreteAttribute;
221 
222   static void bindDerived(ClassTy &c) {
223     c.def_static(
224         "get",
225         [](PyType &type, int64_t value) {
226           MlirAttribute attr = mlirIntegerAttrGet(type, value);
227           return PyIntegerAttribute(type.getContext(), attr);
228         },
229         py::arg("type"), py::arg("value"),
230         "Gets an uniqued integer attribute associated to a type");
231     c.def_property_readonly(
232         "value",
233         [](PyIntegerAttribute &self) {
234           return mlirIntegerAttrGetValueInt(self);
235         },
236         "Returns the value of the integer attribute");
237   }
238 };
239 
240 /// Bool Attribute subclass - BoolAttr.
241 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
242 public:
243   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
244   static constexpr const char *pyClassName = "BoolAttr";
245   using PyConcreteAttribute::PyConcreteAttribute;
246 
247   static void bindDerived(ClassTy &c) {
248     c.def_static(
249         "get",
250         [](bool value, DefaultingPyMlirContext context) {
251           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
252           return PyBoolAttribute(context->getRef(), attr);
253         },
254         py::arg("value"), py::arg("context") = py::none(),
255         "Gets an uniqued bool attribute");
256     c.def_property_readonly(
257         "value",
258         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
259         "Returns the value of the bool attribute");
260   }
261 };
262 
263 class PyFlatSymbolRefAttribute
264     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
265 public:
266   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
267   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
268   using PyConcreteAttribute::PyConcreteAttribute;
269 
270   static void bindDerived(ClassTy &c) {
271     c.def_static(
272         "get",
273         [](std::string value, DefaultingPyMlirContext context) {
274           MlirAttribute attr =
275               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
276           return PyFlatSymbolRefAttribute(context->getRef(), attr);
277         },
278         py::arg("value"), py::arg("context") = py::none(),
279         "Gets a uniqued FlatSymbolRef attribute");
280     c.def_property_readonly(
281         "value",
282         [](PyFlatSymbolRefAttribute &self) {
283           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
284           return py::str(stringRef.data, stringRef.length);
285         },
286         "Returns the value of the FlatSymbolRef attribute as a string");
287   }
288 };
289 
290 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
291 public:
292   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
293   static constexpr const char *pyClassName = "StringAttr";
294   using PyConcreteAttribute::PyConcreteAttribute;
295 
296   static void bindDerived(ClassTy &c) {
297     c.def_static(
298         "get",
299         [](std::string value, DefaultingPyMlirContext context) {
300           MlirAttribute attr =
301               mlirStringAttrGet(context->get(), toMlirStringRef(value));
302           return PyStringAttribute(context->getRef(), attr);
303         },
304         py::arg("value"), py::arg("context") = py::none(),
305         "Gets a uniqued string attribute");
306     c.def_static(
307         "get_typed",
308         [](PyType &type, std::string value) {
309           MlirAttribute attr =
310               mlirStringAttrTypedGet(type, toMlirStringRef(value));
311           return PyStringAttribute(type.getContext(), attr);
312         },
313 
314         "Gets a uniqued string attribute associated to a type");
315     c.def_property_readonly(
316         "value",
317         [](PyStringAttribute &self) {
318           MlirStringRef stringRef = mlirStringAttrGetValue(self);
319           return py::str(stringRef.data, stringRef.length);
320         },
321         "Returns the value of the string attribute");
322   }
323 };
324 
325 // TODO: Support construction of bool elements.
326 // TODO: Support construction of string elements.
327 class PyDenseElementsAttribute
328     : public PyConcreteAttribute<PyDenseElementsAttribute> {
329 public:
330   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
331   static constexpr const char *pyClassName = "DenseElementsAttr";
332   using PyConcreteAttribute::PyConcreteAttribute;
333 
334   static PyDenseElementsAttribute
335   getFromBuffer(py::buffer array, bool signless,
336                 DefaultingPyMlirContext contextWrapper) {
337     // Request a contiguous view. In exotic cases, this will cause a copy.
338     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
339     Py_buffer *view = new Py_buffer();
340     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
341       delete view;
342       throw py::error_already_set();
343     }
344     py::buffer_info arrayInfo(view);
345 
346     MlirContext context = contextWrapper->get();
347     // Switch on the types that can be bulk loaded between the Python and
348     // MLIR-C APIs.
349     // See: https://docs.python.org/3/library/struct.html#format-characters
350     if (arrayInfo.format == "f") {
351       // f32
352       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
353       return PyDenseElementsAttribute(
354           contextWrapper->getRef(),
355           bulkLoad(context, mlirDenseElementsAttrFloatGet,
356                    mlirF32TypeGet(context), arrayInfo));
357     } else if (arrayInfo.format == "d") {
358       // f64
359       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
360       return PyDenseElementsAttribute(
361           contextWrapper->getRef(),
362           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
363                    mlirF64TypeGet(context), arrayInfo));
364     } else if (isSignedIntegerFormat(arrayInfo.format)) {
365       if (arrayInfo.itemsize == 4) {
366         // i32
367         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
368                                         : mlirIntegerTypeSignedGet(context, 32);
369         return PyDenseElementsAttribute(contextWrapper->getRef(),
370                                         bulkLoad(context,
371                                                  mlirDenseElementsAttrInt32Get,
372                                                  elementType, arrayInfo));
373       } else if (arrayInfo.itemsize == 8) {
374         // i64
375         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
376                                         : mlirIntegerTypeSignedGet(context, 64);
377         return PyDenseElementsAttribute(contextWrapper->getRef(),
378                                         bulkLoad(context,
379                                                  mlirDenseElementsAttrInt64Get,
380                                                  elementType, arrayInfo));
381       }
382     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
383       if (arrayInfo.itemsize == 4) {
384         // unsigned i32
385         MlirType elementType = signless
386                                    ? mlirIntegerTypeGet(context, 32)
387                                    : mlirIntegerTypeUnsignedGet(context, 32);
388         return PyDenseElementsAttribute(contextWrapper->getRef(),
389                                         bulkLoad(context,
390                                                  mlirDenseElementsAttrUInt32Get,
391                                                  elementType, arrayInfo));
392       } else if (arrayInfo.itemsize == 8) {
393         // unsigned i64
394         MlirType elementType = signless
395                                    ? mlirIntegerTypeGet(context, 64)
396                                    : mlirIntegerTypeUnsignedGet(context, 64);
397         return PyDenseElementsAttribute(contextWrapper->getRef(),
398                                         bulkLoad(context,
399                                                  mlirDenseElementsAttrUInt64Get,
400                                                  elementType, arrayInfo));
401       }
402     }
403 
404     // TODO: Fall back to string-based get.
405     std::string message = "unimplemented array format conversion from format: ";
406     message.append(arrayInfo.format);
407     throw SetPyError(PyExc_ValueError, message);
408   }
409 
410   static PyDenseElementsAttribute getSplat(PyType shapedType,
411                                            PyAttribute &elementAttr) {
412     auto contextWrapper =
413         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
414     if (!mlirAttributeIsAInteger(elementAttr) &&
415         !mlirAttributeIsAFloat(elementAttr)) {
416       std::string message = "Illegal element type for DenseElementsAttr: ";
417       message.append(py::repr(py::cast(elementAttr)));
418       throw SetPyError(PyExc_ValueError, message);
419     }
420     if (!mlirTypeIsAShaped(shapedType) ||
421         !mlirShapedTypeHasStaticShape(shapedType)) {
422       std::string message =
423           "Expected a static ShapedType for the shaped_type parameter: ";
424       message.append(py::repr(py::cast(shapedType)));
425       throw SetPyError(PyExc_ValueError, message);
426     }
427     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
428     MlirType attrType = mlirAttributeGetType(elementAttr);
429     if (!mlirTypeEqual(shapedElementType, attrType)) {
430       std::string message =
431           "Shaped element type and attribute type must be equal: shaped=";
432       message.append(py::repr(py::cast(shapedType)));
433       message.append(", element=");
434       message.append(py::repr(py::cast(elementAttr)));
435       throw SetPyError(PyExc_ValueError, message);
436     }
437 
438     MlirAttribute elements =
439         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
440     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
441   }
442 
443   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
444 
445   py::buffer_info accessBuffer() {
446     MlirType shapedType = mlirAttributeGetType(*this);
447     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
448 
449     if (mlirTypeIsAF32(elementType)) {
450       // f32
451       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
452     } else if (mlirTypeIsAF64(elementType)) {
453       // f64
454       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
455     } else if (mlirTypeIsAInteger(elementType) &&
456                mlirIntegerTypeGetWidth(elementType) == 32) {
457       if (mlirIntegerTypeIsSignless(elementType) ||
458           mlirIntegerTypeIsSigned(elementType)) {
459         // i32
460         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
461       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
462         // unsigned i32
463         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
464       }
465     } else if (mlirTypeIsAInteger(elementType) &&
466                mlirIntegerTypeGetWidth(elementType) == 64) {
467       if (mlirIntegerTypeIsSignless(elementType) ||
468           mlirIntegerTypeIsSigned(elementType)) {
469         // i64
470         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
471       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
472         // unsigned i64
473         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
474       }
475     }
476 
477     std::string message = "unimplemented array format.";
478     throw SetPyError(PyExc_ValueError, message);
479   }
480 
481   static void bindDerived(ClassTy &c) {
482     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
483         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
484                     py::arg("array"), py::arg("signless") = true,
485                     py::arg("context") = py::none(),
486                     "Gets from a buffer or ndarray")
487         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
488                     py::arg("shaped_type"), py::arg("element_attr"),
489                     "Gets a DenseElementsAttr where all values are the same")
490         .def_property_readonly("is_splat",
491                                [](PyDenseElementsAttribute &self) -> bool {
492                                  return mlirDenseElementsAttrIsSplat(self);
493                                })
494         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
495   }
496 
497 private:
498   template <typename ElementTy>
499   static MlirAttribute
500   bulkLoad(MlirContext context,
501            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
502            MlirType mlirElementType, py::buffer_info &arrayInfo) {
503     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
504                                   arrayInfo.shape.begin() + arrayInfo.ndim);
505     MlirAttribute encodingAttr = mlirAttributeGetNull();
506     auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
507                                               mlirElementType, encodingAttr);
508     intptr_t numElements = arrayInfo.size;
509     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
510     return ctor(shapedType, numElements, contents);
511   }
512 
513   static bool isUnsignedIntegerFormat(const std::string &format) {
514     if (format.empty())
515       return false;
516     char code = format[0];
517     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
518            code == 'Q';
519   }
520 
521   static bool isSignedIntegerFormat(const std::string &format) {
522     if (format.empty())
523       return false;
524     char code = format[0];
525     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
526            code == 'q';
527   }
528 
529   template <typename Type>
530   py::buffer_info bufferInfo(MlirType shapedType,
531                              Type (*value)(MlirAttribute, intptr_t)) {
532     intptr_t rank = mlirShapedTypeGetRank(shapedType);
533     // Prepare the data for the buffer_info.
534     // Buffer is configured for read-only access below.
535     Type *data = static_cast<Type *>(
536         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
537     // Prepare the shape for the buffer_info.
538     SmallVector<intptr_t, 4> shape;
539     for (intptr_t i = 0; i < rank; ++i)
540       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
541     // Prepare the strides for the buffer_info.
542     SmallVector<intptr_t, 4> strides;
543     intptr_t strideFactor = 1;
544     for (intptr_t i = 1; i < rank; ++i) {
545       strideFactor = 1;
546       for (intptr_t j = i; j < rank; ++j) {
547         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
548       }
549       strides.push_back(sizeof(Type) * strideFactor);
550     }
551     strides.push_back(sizeof(Type));
552     return py::buffer_info(data, sizeof(Type),
553                            py::format_descriptor<Type>::format(), rank, shape,
554                            strides, /*readonly=*/true);
555   }
556 }; // namespace
557 
558 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
559 /// (and boolean) values. Supports element access.
560 class PyDenseIntElementsAttribute
561     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
562                                  PyDenseElementsAttribute> {
563 public:
564   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
565   static constexpr const char *pyClassName = "DenseIntElementsAttr";
566   using PyConcreteAttribute::PyConcreteAttribute;
567 
568   /// Returns the element at the given linear position. Asserts if the index is
569   /// out of range.
570   py::int_ dunderGetItem(intptr_t pos) {
571     if (pos < 0 || pos >= dunderLen()) {
572       throw SetPyError(PyExc_IndexError,
573                        "attempt to access out of bounds element");
574     }
575 
576     MlirType type = mlirAttributeGetType(*this);
577     type = mlirShapedTypeGetElementType(type);
578     assert(mlirTypeIsAInteger(type) &&
579            "expected integer element type in dense int elements attribute");
580     // Dispatch element extraction to an appropriate C function based on the
581     // elemental type of the attribute. py::int_ is implicitly constructible
582     // from any C++ integral type and handles bitwidth correctly.
583     // TODO: consider caching the type properties in the constructor to avoid
584     // querying them on each element access.
585     unsigned width = mlirIntegerTypeGetWidth(type);
586     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
587     if (isUnsigned) {
588       if (width == 1) {
589         return mlirDenseElementsAttrGetBoolValue(*this, pos);
590       }
591       if (width == 32) {
592         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
593       }
594       if (width == 64) {
595         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
596       }
597     } else {
598       if (width == 1) {
599         return mlirDenseElementsAttrGetBoolValue(*this, pos);
600       }
601       if (width == 32) {
602         return mlirDenseElementsAttrGetInt32Value(*this, pos);
603       }
604       if (width == 64) {
605         return mlirDenseElementsAttrGetInt64Value(*this, pos);
606       }
607     }
608     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
609   }
610 
611   static void bindDerived(ClassTy &c) {
612     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
613   }
614 };
615 
616 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
617 public:
618   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
619   static constexpr const char *pyClassName = "DictAttr";
620   using PyConcreteAttribute::PyConcreteAttribute;
621 
622   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
623 
624   static void bindDerived(ClassTy &c) {
625     c.def("__len__", &PyDictAttribute::dunderLen);
626     c.def_static(
627         "get",
628         [](py::dict attributes, DefaultingPyMlirContext context) {
629           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
630           mlirNamedAttributes.reserve(attributes.size());
631           for (auto &it : attributes) {
632             auto &mlir_attr = it.second.cast<PyAttribute &>();
633             auto name = it.first.cast<std::string>();
634             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
635                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
636                                   toMlirStringRef(name)),
637                 mlir_attr));
638           }
639           MlirAttribute attr =
640               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
641                                     mlirNamedAttributes.data());
642           return PyDictAttribute(context->getRef(), attr);
643         },
644         py::arg("value"), py::arg("context") = py::none(),
645         "Gets an uniqued dict attribute");
646     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
647       MlirAttribute attr =
648           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
649       if (mlirAttributeIsNull(attr)) {
650         throw SetPyError(PyExc_KeyError,
651                          "attempt to access a non-existent attribute");
652       }
653       return PyAttribute(self.getContext(), attr);
654     });
655     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
656       if (index < 0 || index >= self.dunderLen()) {
657         throw SetPyError(PyExc_IndexError,
658                          "attempt to access out of bounds attribute");
659       }
660       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
661       return PyNamedAttribute(
662           namedAttr.attribute,
663           std::string(mlirIdentifierStr(namedAttr.name).data));
664     });
665   }
666 };
667 
668 /// Refinement of PyDenseElementsAttribute for attributes containing
669 /// floating-point values. Supports element access.
670 class PyDenseFPElementsAttribute
671     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
672                                  PyDenseElementsAttribute> {
673 public:
674   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
675   static constexpr const char *pyClassName = "DenseFPElementsAttr";
676   using PyConcreteAttribute::PyConcreteAttribute;
677 
678   py::float_ dunderGetItem(intptr_t pos) {
679     if (pos < 0 || pos >= dunderLen()) {
680       throw SetPyError(PyExc_IndexError,
681                        "attempt to access out of bounds element");
682     }
683 
684     MlirType type = mlirAttributeGetType(*this);
685     type = mlirShapedTypeGetElementType(type);
686     // Dispatch element extraction to an appropriate C function based on the
687     // elemental type of the attribute. py::float_ is implicitly constructible
688     // from float and double.
689     // TODO: consider caching the type properties in the constructor to avoid
690     // querying them on each element access.
691     if (mlirTypeIsAF32(type)) {
692       return mlirDenseElementsAttrGetFloatValue(*this, pos);
693     }
694     if (mlirTypeIsAF64(type)) {
695       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
696     }
697     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
698   }
699 
700   static void bindDerived(ClassTy &c) {
701     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
702   }
703 };
704 
705 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
706 public:
707   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
708   static constexpr const char *pyClassName = "TypeAttr";
709   using PyConcreteAttribute::PyConcreteAttribute;
710 
711   static void bindDerived(ClassTy &c) {
712     c.def_static(
713         "get",
714         [](PyType value, DefaultingPyMlirContext context) {
715           MlirAttribute attr = mlirTypeAttrGet(value.get());
716           return PyTypeAttribute(context->getRef(), attr);
717         },
718         py::arg("value"), py::arg("context") = py::none(),
719         "Gets a uniqued Type attribute");
720     c.def_property_readonly("value", [](PyTypeAttribute &self) {
721       return PyType(self.getContext()->getRef(),
722                     mlirTypeAttrGetValue(self.get()));
723     });
724   }
725 };
726 
727 /// Unit Attribute subclass. Unit attributes don't have values.
728 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
729 public:
730   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
731   static constexpr const char *pyClassName = "UnitAttr";
732   using PyConcreteAttribute::PyConcreteAttribute;
733 
734   static void bindDerived(ClassTy &c) {
735     c.def_static(
736         "get",
737         [](DefaultingPyMlirContext context) {
738           return PyUnitAttribute(context->getRef(),
739                                  mlirUnitAttrGet(context->get()));
740         },
741         py::arg("context") = py::none(), "Create a Unit attribute.");
742   }
743 };
744 
745 } // namespace
746 
/// Registers every attribute wrapper class defined in this file on module `m`.
/// NOTE(review): every class here names PyAttribute (directly or through
/// PyDenseElementsAttribute) as its pybind11 base. PyAttribute is therefore
/// presumably bound before this function runs; confirm at the call site.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Order matters here: DenseElementsAttr is the pybind11 base class of the
  // DenseFPElementsAttr and DenseIntElementsAttr wrappers bound next.
  // pybind11 requires a base class to be registered before its subclasses.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
763