1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::StringRef;
22 using llvm::Twine;
23 
24 namespace {
25 
26 static MlirStringRef toMlirStringRef(const std::string &s) {
27   return mlirStringRefCreate(s.data(), s.size());
28 }
29 
30 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
31 public:
32   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
33   static constexpr const char *pyClassName = "AffineMapAttr";
34   using PyConcreteAttribute::PyConcreteAttribute;
35 
36   static void bindDerived(ClassTy &c) {
37     c.def_static(
38         "get",
39         [](PyAffineMap &affineMap) {
40           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
41           return PyAffineMapAttribute(affineMap.getContext(), attr);
42         },
43         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
44   }
45 };
46 
47 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
48 public:
49   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
50   static constexpr const char *pyClassName = "ArrayAttr";
51   using PyConcreteAttribute::PyConcreteAttribute;
52 
53   class PyArrayAttributeIterator {
54   public:
55     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
56 
57     PyArrayAttributeIterator &dunderIter() { return *this; }
58 
59     PyAttribute dunderNext() {
60       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
61         throw py::stop_iteration();
62       }
63       return PyAttribute(attr.getContext(),
64                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
65     }
66 
67     static void bind(py::module &m) {
68       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
69                                            py::module_local())
70           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
71           .def("__next__", &PyArrayAttributeIterator::dunderNext);
72     }
73 
74   private:
75     PyAttribute attr;
76     int nextIndex = 0;
77   };
78 
79   static void bindDerived(ClassTy &c) {
80     c.def_static(
81         "get",
82         [](py::list attributes, DefaultingPyMlirContext context) {
83           SmallVector<MlirAttribute> mlirAttributes;
84           mlirAttributes.reserve(py::len(attributes));
85           for (auto attribute : attributes) {
86             try {
87               mlirAttributes.push_back(attribute.cast<PyAttribute>());
88             } catch (py::cast_error &err) {
89               std::string msg = std::string("Invalid attribute when attempting "
90                                             "to create an ArrayAttribute (") +
91                                 err.what() + ")";
92               throw py::cast_error(msg);
93             } catch (py::reference_cast_error &err) {
94               // This exception seems thrown when the value is "None".
95               std::string msg =
96                   std::string("Invalid attribute (None?) when attempting to "
97                               "create an ArrayAttribute (") +
98                   err.what() + ")";
99               throw py::cast_error(msg);
100             }
101           }
102           MlirAttribute attr = mlirArrayAttrGet(
103               context->get(), mlirAttributes.size(), mlirAttributes.data());
104           return PyArrayAttribute(context->getRef(), attr);
105         },
106         py::arg("attributes"), py::arg("context") = py::none(),
107         "Gets a uniqued Array attribute");
108     c.def("__getitem__",
109           [](PyArrayAttribute &arr, intptr_t i) {
110             if (i >= mlirArrayAttrGetNumElements(arr))
111               throw py::index_error("ArrayAttribute index out of range");
112             return PyAttribute(arr.getContext(),
113                                mlirArrayAttrGetElement(arr, i));
114           })
115         .def("__len__",
116              [](const PyArrayAttribute &arr) {
117                return mlirArrayAttrGetNumElements(arr);
118              })
119         .def("__iter__", [](const PyArrayAttribute &arr) {
120           return PyArrayAttributeIterator(arr);
121         });
122   }
123 };
124 
125 /// Float Point Attribute subclass - FloatAttr.
126 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
127 public:
128   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
129   static constexpr const char *pyClassName = "FloatAttr";
130   using PyConcreteAttribute::PyConcreteAttribute;
131 
132   static void bindDerived(ClassTy &c) {
133     c.def_static(
134         "get",
135         [](PyType &type, double value, DefaultingPyLocation loc) {
136           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
137           // TODO: Rework error reporting once diagnostic engine is exposed
138           // in C API.
139           if (mlirAttributeIsNull(attr)) {
140             throw SetPyError(PyExc_ValueError,
141                              Twine("invalid '") +
142                                  py::repr(py::cast(type)).cast<std::string>() +
143                                  "' and expected floating point type.");
144           }
145           return PyFloatAttribute(type.getContext(), attr);
146         },
147         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
148         "Gets an uniqued float point attribute associated to a type");
149     c.def_static(
150         "get_f32",
151         [](double value, DefaultingPyMlirContext context) {
152           MlirAttribute attr = mlirFloatAttrDoubleGet(
153               context->get(), mlirF32TypeGet(context->get()), value);
154           return PyFloatAttribute(context->getRef(), attr);
155         },
156         py::arg("value"), py::arg("context") = py::none(),
157         "Gets an uniqued float point attribute associated to a f32 type");
158     c.def_static(
159         "get_f64",
160         [](double value, DefaultingPyMlirContext context) {
161           MlirAttribute attr = mlirFloatAttrDoubleGet(
162               context->get(), mlirF64TypeGet(context->get()), value);
163           return PyFloatAttribute(context->getRef(), attr);
164         },
165         py::arg("value"), py::arg("context") = py::none(),
166         "Gets an uniqued float point attribute associated to a f64 type");
167     c.def_property_readonly(
168         "value",
169         [](PyFloatAttribute &self) {
170           return mlirFloatAttrGetValueDouble(self);
171         },
172         "Returns the value of the float point attribute");
173   }
174 };
175 
176 /// Integer Attribute subclass - IntegerAttr.
177 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
178 public:
179   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
180   static constexpr const char *pyClassName = "IntegerAttr";
181   using PyConcreteAttribute::PyConcreteAttribute;
182 
183   static void bindDerived(ClassTy &c) {
184     c.def_static(
185         "get",
186         [](PyType &type, int64_t value) {
187           MlirAttribute attr = mlirIntegerAttrGet(type, value);
188           return PyIntegerAttribute(type.getContext(), attr);
189         },
190         py::arg("type"), py::arg("value"),
191         "Gets an uniqued integer attribute associated to a type");
192     c.def_property_readonly(
193         "value",
194         [](PyIntegerAttribute &self) {
195           return mlirIntegerAttrGetValueInt(self);
196         },
197         "Returns the value of the integer attribute");
198   }
199 };
200 
201 /// Bool Attribute subclass - BoolAttr.
202 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
203 public:
204   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
205   static constexpr const char *pyClassName = "BoolAttr";
206   using PyConcreteAttribute::PyConcreteAttribute;
207 
208   static void bindDerived(ClassTy &c) {
209     c.def_static(
210         "get",
211         [](bool value, DefaultingPyMlirContext context) {
212           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
213           return PyBoolAttribute(context->getRef(), attr);
214         },
215         py::arg("value"), py::arg("context") = py::none(),
216         "Gets an uniqued bool attribute");
217     c.def_property_readonly(
218         "value",
219         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
220         "Returns the value of the bool attribute");
221   }
222 };
223 
224 class PyFlatSymbolRefAttribute
225     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
226 public:
227   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
228   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
229   using PyConcreteAttribute::PyConcreteAttribute;
230 
231   static void bindDerived(ClassTy &c) {
232     c.def_static(
233         "get",
234         [](std::string value, DefaultingPyMlirContext context) {
235           MlirAttribute attr =
236               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
237           return PyFlatSymbolRefAttribute(context->getRef(), attr);
238         },
239         py::arg("value"), py::arg("context") = py::none(),
240         "Gets a uniqued FlatSymbolRef attribute");
241     c.def_property_readonly(
242         "value",
243         [](PyFlatSymbolRefAttribute &self) {
244           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
245           return py::str(stringRef.data, stringRef.length);
246         },
247         "Returns the value of the FlatSymbolRef attribute as a string");
248   }
249 };
250 
251 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
252 public:
253   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
254   static constexpr const char *pyClassName = "StringAttr";
255   using PyConcreteAttribute::PyConcreteAttribute;
256 
257   static void bindDerived(ClassTy &c) {
258     c.def_static(
259         "get",
260         [](std::string value, DefaultingPyMlirContext context) {
261           MlirAttribute attr =
262               mlirStringAttrGet(context->get(), toMlirStringRef(value));
263           return PyStringAttribute(context->getRef(), attr);
264         },
265         py::arg("value"), py::arg("context") = py::none(),
266         "Gets a uniqued string attribute");
267     c.def_static(
268         "get_typed",
269         [](PyType &type, std::string value) {
270           MlirAttribute attr =
271               mlirStringAttrTypedGet(type, toMlirStringRef(value));
272           return PyStringAttribute(type.getContext(), attr);
273         },
274 
275         "Gets a uniqued string attribute associated to a type");
276     c.def_property_readonly(
277         "value",
278         [](PyStringAttribute &self) {
279           MlirStringRef stringRef = mlirStringAttrGetValue(self);
280           return py::str(stringRef.data, stringRef.length);
281         },
282         "Returns the value of the string attribute");
283   }
284 };
285 
286 // TODO: Support construction of bool elements.
287 // TODO: Support construction of string elements.
288 class PyDenseElementsAttribute
289     : public PyConcreteAttribute<PyDenseElementsAttribute> {
290 public:
291   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
292   static constexpr const char *pyClassName = "DenseElementsAttr";
293   using PyConcreteAttribute::PyConcreteAttribute;
294 
295   static PyDenseElementsAttribute
296   getFromBuffer(py::buffer array, bool signless,
297                 DefaultingPyMlirContext contextWrapper) {
298     // Request a contiguous view. In exotic cases, this will cause a copy.
299     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
300     Py_buffer *view = new Py_buffer();
301     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
302       delete view;
303       throw py::error_already_set();
304     }
305     py::buffer_info arrayInfo(view);
306 
307     MlirContext context = contextWrapper->get();
308     // Switch on the types that can be bulk loaded between the Python and
309     // MLIR-C APIs.
310     // See: https://docs.python.org/3/library/struct.html#format-characters
311     if (arrayInfo.format == "f") {
312       // f32
313       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
314       return PyDenseElementsAttribute(
315           contextWrapper->getRef(),
316           bulkLoad(context, mlirDenseElementsAttrFloatGet,
317                    mlirF32TypeGet(context), arrayInfo));
318     } else if (arrayInfo.format == "d") {
319       // f64
320       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
321       return PyDenseElementsAttribute(
322           contextWrapper->getRef(),
323           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
324                    mlirF64TypeGet(context), arrayInfo));
325     } else if (isSignedIntegerFormat(arrayInfo.format)) {
326       if (arrayInfo.itemsize == 4) {
327         // i32
328         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
329                                         : mlirIntegerTypeSignedGet(context, 32);
330         return PyDenseElementsAttribute(contextWrapper->getRef(),
331                                         bulkLoad(context,
332                                                  mlirDenseElementsAttrInt32Get,
333                                                  elementType, arrayInfo));
334       } else if (arrayInfo.itemsize == 8) {
335         // i64
336         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
337                                         : mlirIntegerTypeSignedGet(context, 64);
338         return PyDenseElementsAttribute(contextWrapper->getRef(),
339                                         bulkLoad(context,
340                                                  mlirDenseElementsAttrInt64Get,
341                                                  elementType, arrayInfo));
342       }
343     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
344       if (arrayInfo.itemsize == 4) {
345         // unsigned i32
346         MlirType elementType = signless
347                                    ? mlirIntegerTypeGet(context, 32)
348                                    : mlirIntegerTypeUnsignedGet(context, 32);
349         return PyDenseElementsAttribute(contextWrapper->getRef(),
350                                         bulkLoad(context,
351                                                  mlirDenseElementsAttrUInt32Get,
352                                                  elementType, arrayInfo));
353       } else if (arrayInfo.itemsize == 8) {
354         // unsigned i64
355         MlirType elementType = signless
356                                    ? mlirIntegerTypeGet(context, 64)
357                                    : mlirIntegerTypeUnsignedGet(context, 64);
358         return PyDenseElementsAttribute(contextWrapper->getRef(),
359                                         bulkLoad(context,
360                                                  mlirDenseElementsAttrUInt64Get,
361                                                  elementType, arrayInfo));
362       }
363     }
364 
365     // TODO: Fall back to string-based get.
366     std::string message = "unimplemented array format conversion from format: ";
367     message.append(arrayInfo.format);
368     throw SetPyError(PyExc_ValueError, message);
369   }
370 
371   static PyDenseElementsAttribute getSplat(PyType shapedType,
372                                            PyAttribute &elementAttr) {
373     auto contextWrapper =
374         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
375     if (!mlirAttributeIsAInteger(elementAttr) &&
376         !mlirAttributeIsAFloat(elementAttr)) {
377       std::string message = "Illegal element type for DenseElementsAttr: ";
378       message.append(py::repr(py::cast(elementAttr)));
379       throw SetPyError(PyExc_ValueError, message);
380     }
381     if (!mlirTypeIsAShaped(shapedType) ||
382         !mlirShapedTypeHasStaticShape(shapedType)) {
383       std::string message =
384           "Expected a static ShapedType for the shaped_type parameter: ";
385       message.append(py::repr(py::cast(shapedType)));
386       throw SetPyError(PyExc_ValueError, message);
387     }
388     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
389     MlirType attrType = mlirAttributeGetType(elementAttr);
390     if (!mlirTypeEqual(shapedElementType, attrType)) {
391       std::string message =
392           "Shaped element type and attribute type must be equal: shaped=";
393       message.append(py::repr(py::cast(shapedType)));
394       message.append(", element=");
395       message.append(py::repr(py::cast(elementAttr)));
396       throw SetPyError(PyExc_ValueError, message);
397     }
398 
399     MlirAttribute elements =
400         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
401     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
402   }
403 
404   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
405 
406   py::buffer_info accessBuffer() {
407     MlirType shapedType = mlirAttributeGetType(*this);
408     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
409 
410     if (mlirTypeIsAF32(elementType)) {
411       // f32
412       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
413     } else if (mlirTypeIsAF64(elementType)) {
414       // f64
415       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
416     } else if (mlirTypeIsAInteger(elementType) &&
417                mlirIntegerTypeGetWidth(elementType) == 32) {
418       if (mlirIntegerTypeIsSignless(elementType) ||
419           mlirIntegerTypeIsSigned(elementType)) {
420         // i32
421         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
422       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
423         // unsigned i32
424         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
425       }
426     } else if (mlirTypeIsAInteger(elementType) &&
427                mlirIntegerTypeGetWidth(elementType) == 64) {
428       if (mlirIntegerTypeIsSignless(elementType) ||
429           mlirIntegerTypeIsSigned(elementType)) {
430         // i64
431         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
432       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
433         // unsigned i64
434         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
435       }
436     }
437 
438     std::string message = "unimplemented array format.";
439     throw SetPyError(PyExc_ValueError, message);
440   }
441 
442   static void bindDerived(ClassTy &c) {
443     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
444         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
445                     py::arg("array"), py::arg("signless") = true,
446                     py::arg("context") = py::none(),
447                     "Gets from a buffer or ndarray")
448         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
449                     py::arg("shaped_type"), py::arg("element_attr"),
450                     "Gets a DenseElementsAttr where all values are the same")
451         .def_property_readonly("is_splat",
452                                [](PyDenseElementsAttribute &self) -> bool {
453                                  return mlirDenseElementsAttrIsSplat(self);
454                                })
455         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
456   }
457 
458 private:
459   template <typename ElementTy>
460   static MlirAttribute
461   bulkLoad(MlirContext context,
462            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
463            MlirType mlirElementType, py::buffer_info &arrayInfo) {
464     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
465                                   arrayInfo.shape.begin() + arrayInfo.ndim);
466     MlirAttribute encodingAttr = mlirAttributeGetNull();
467     auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
468                                               mlirElementType, encodingAttr);
469     intptr_t numElements = arrayInfo.size;
470     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
471     return ctor(shapedType, numElements, contents);
472   }
473 
474   static bool isUnsignedIntegerFormat(const std::string &format) {
475     if (format.empty())
476       return false;
477     char code = format[0];
478     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
479            code == 'Q';
480   }
481 
482   static bool isSignedIntegerFormat(const std::string &format) {
483     if (format.empty())
484       return false;
485     char code = format[0];
486     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
487            code == 'q';
488   }
489 
490   template <typename Type>
491   py::buffer_info bufferInfo(MlirType shapedType,
492                              Type (*value)(MlirAttribute, intptr_t)) {
493     intptr_t rank = mlirShapedTypeGetRank(shapedType);
494     // Prepare the data for the buffer_info.
495     // Buffer is configured for read-only access below.
496     Type *data = static_cast<Type *>(
497         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
498     // Prepare the shape for the buffer_info.
499     SmallVector<intptr_t, 4> shape;
500     for (intptr_t i = 0; i < rank; ++i)
501       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
502     // Prepare the strides for the buffer_info.
503     SmallVector<intptr_t, 4> strides;
504     intptr_t strideFactor = 1;
505     for (intptr_t i = 1; i < rank; ++i) {
506       strideFactor = 1;
507       for (intptr_t j = i; j < rank; ++j) {
508         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
509       }
510       strides.push_back(sizeof(Type) * strideFactor);
511     }
512     strides.push_back(sizeof(Type));
513     return py::buffer_info(data, sizeof(Type),
514                            py::format_descriptor<Type>::format(), rank, shape,
515                            strides, /*readonly=*/true);
516   }
517 }; // namespace
518 
519 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
520 /// (and boolean) values. Supports element access.
521 class PyDenseIntElementsAttribute
522     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
523                                  PyDenseElementsAttribute> {
524 public:
525   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
526   static constexpr const char *pyClassName = "DenseIntElementsAttr";
527   using PyConcreteAttribute::PyConcreteAttribute;
528 
529   /// Returns the element at the given linear position. Asserts if the index is
530   /// out of range.
531   py::int_ dunderGetItem(intptr_t pos) {
532     if (pos < 0 || pos >= dunderLen()) {
533       throw SetPyError(PyExc_IndexError,
534                        "attempt to access out of bounds element");
535     }
536 
537     MlirType type = mlirAttributeGetType(*this);
538     type = mlirShapedTypeGetElementType(type);
539     assert(mlirTypeIsAInteger(type) &&
540            "expected integer element type in dense int elements attribute");
541     // Dispatch element extraction to an appropriate C function based on the
542     // elemental type of the attribute. py::int_ is implicitly constructible
543     // from any C++ integral type and handles bitwidth correctly.
544     // TODO: consider caching the type properties in the constructor to avoid
545     // querying them on each element access.
546     unsigned width = mlirIntegerTypeGetWidth(type);
547     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
548     if (isUnsigned) {
549       if (width == 1) {
550         return mlirDenseElementsAttrGetBoolValue(*this, pos);
551       }
552       if (width == 32) {
553         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
554       }
555       if (width == 64) {
556         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
557       }
558     } else {
559       if (width == 1) {
560         return mlirDenseElementsAttrGetBoolValue(*this, pos);
561       }
562       if (width == 32) {
563         return mlirDenseElementsAttrGetInt32Value(*this, pos);
564       }
565       if (width == 64) {
566         return mlirDenseElementsAttrGetInt64Value(*this, pos);
567       }
568     }
569     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
570   }
571 
572   static void bindDerived(ClassTy &c) {
573     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
574   }
575 };
576 
577 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
578 public:
579   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
580   static constexpr const char *pyClassName = "DictAttr";
581   using PyConcreteAttribute::PyConcreteAttribute;
582 
583   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
584 
585   static void bindDerived(ClassTy &c) {
586     c.def("__len__", &PyDictAttribute::dunderLen);
587     c.def_static(
588         "get",
589         [](py::dict attributes, DefaultingPyMlirContext context) {
590           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
591           mlirNamedAttributes.reserve(attributes.size());
592           for (auto &it : attributes) {
593             auto &mlir_attr = it.second.cast<PyAttribute &>();
594             auto name = it.first.cast<std::string>();
595             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
596                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
597                                   toMlirStringRef(name)),
598                 mlir_attr));
599           }
600           MlirAttribute attr =
601               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
602                                     mlirNamedAttributes.data());
603           return PyDictAttribute(context->getRef(), attr);
604         },
605         py::arg("value"), py::arg("context") = py::none(),
606         "Gets an uniqued dict attribute");
607     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
608       MlirAttribute attr =
609           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
610       if (mlirAttributeIsNull(attr)) {
611         throw SetPyError(PyExc_KeyError,
612                          "attempt to access a non-existent attribute");
613       }
614       return PyAttribute(self.getContext(), attr);
615     });
616     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
617       if (index < 0 || index >= self.dunderLen()) {
618         throw SetPyError(PyExc_IndexError,
619                          "attempt to access out of bounds attribute");
620       }
621       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
622       return PyNamedAttribute(
623           namedAttr.attribute,
624           std::string(mlirIdentifierStr(namedAttr.name).data));
625     });
626   }
627 };
628 
629 /// Refinement of PyDenseElementsAttribute for attributes containing
630 /// floating-point values. Supports element access.
631 class PyDenseFPElementsAttribute
632     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
633                                  PyDenseElementsAttribute> {
634 public:
635   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
636   static constexpr const char *pyClassName = "DenseFPElementsAttr";
637   using PyConcreteAttribute::PyConcreteAttribute;
638 
639   py::float_ dunderGetItem(intptr_t pos) {
640     if (pos < 0 || pos >= dunderLen()) {
641       throw SetPyError(PyExc_IndexError,
642                        "attempt to access out of bounds element");
643     }
644 
645     MlirType type = mlirAttributeGetType(*this);
646     type = mlirShapedTypeGetElementType(type);
647     // Dispatch element extraction to an appropriate C function based on the
648     // elemental type of the attribute. py::float_ is implicitly constructible
649     // from float and double.
650     // TODO: consider caching the type properties in the constructor to avoid
651     // querying them on each element access.
652     if (mlirTypeIsAF32(type)) {
653       return mlirDenseElementsAttrGetFloatValue(*this, pos);
654     }
655     if (mlirTypeIsAF64(type)) {
656       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
657     }
658     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
659   }
660 
661   static void bindDerived(ClassTy &c) {
662     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
663   }
664 };
665 
666 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
667 public:
668   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
669   static constexpr const char *pyClassName = "TypeAttr";
670   using PyConcreteAttribute::PyConcreteAttribute;
671 
672   static void bindDerived(ClassTy &c) {
673     c.def_static(
674         "get",
675         [](PyType value, DefaultingPyMlirContext context) {
676           MlirAttribute attr = mlirTypeAttrGet(value.get());
677           return PyTypeAttribute(context->getRef(), attr);
678         },
679         py::arg("value"), py::arg("context") = py::none(),
680         "Gets a uniqued Type attribute");
681     c.def_property_readonly("value", [](PyTypeAttribute &self) {
682       return PyType(self.getContext()->getRef(),
683                     mlirTypeAttrGetValue(self.get()));
684     });
685   }
686 };
687 
688 /// Unit Attribute subclass. Unit attributes don't have values.
689 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
690 public:
691   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
692   static constexpr const char *pyClassName = "UnitAttr";
693   using PyConcreteAttribute::PyConcreteAttribute;
694 
695   static void bindDerived(ClassTy &c) {
696     c.def_static(
697         "get",
698         [](DefaultingPyMlirContext context) {
699           return PyUnitAttribute(context->getRef(),
700                                  mlirUnitAttrGet(context->get()));
701         },
702         py::arg("context") = py::none(), "Create a Unit attribute.");
703   }
704 };
705 
706 } // namespace
707 
/// Registers every attribute binding defined in this file on module `m`.
// NOTE(review): DenseFPElementsAttr and DenseIntElementsAttr derive from
// PyDenseElementsAttribute (via PyConcreteAttribute's base parameter) and are
// bound after it. Presumably the base must be registered with pybind11 first,
// so keep that ordering when editing this list.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  // Iterator class returned by ArrayAttr.__iter__.
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
724