1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::StringRef;
22 using llvm::Twine;
23 
24 namespace {
25 
26 static MlirStringRef toMlirStringRef(const std::string &s) {
27   return mlirStringRefCreate(s.data(), s.size());
28 }
29 
30 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
31 public:
32   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
33   static constexpr const char *pyClassName = "AffineMapAttr";
34   using PyConcreteAttribute::PyConcreteAttribute;
35 
36   static void bindDerived(ClassTy &c) {
37     c.def_static(
38         "get",
39         [](PyAffineMap &affineMap) {
40           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
41           return PyAffineMapAttribute(affineMap.getContext(), attr);
42         },
43         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
44   }
45 };
46 
47 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
48 public:
49   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
50   static constexpr const char *pyClassName = "ArrayAttr";
51   using PyConcreteAttribute::PyConcreteAttribute;
52 
53   class PyArrayAttributeIterator {
54   public:
55     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
56 
57     PyArrayAttributeIterator &dunderIter() { return *this; }
58 
59     PyAttribute dunderNext() {
60       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
61         throw py::stop_iteration();
62       }
63       return PyAttribute(attr.getContext(),
64                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
65     }
66 
67     static void bind(py::module &m) {
68       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
69           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
70           .def("__next__", &PyArrayAttributeIterator::dunderNext);
71     }
72 
73   private:
74     PyAttribute attr;
75     int nextIndex = 0;
76   };
77 
78   static void bindDerived(ClassTy &c) {
79     c.def_static(
80         "get",
81         [](py::list attributes, DefaultingPyMlirContext context) {
82           SmallVector<MlirAttribute> mlirAttributes;
83           mlirAttributes.reserve(py::len(attributes));
84           for (auto attribute : attributes) {
85             try {
86               mlirAttributes.push_back(attribute.cast<PyAttribute>());
87             } catch (py::cast_error &err) {
88               std::string msg = std::string("Invalid attribute when attempting "
89                                             "to create an ArrayAttribute (") +
90                                 err.what() + ")";
91               throw py::cast_error(msg);
92             } catch (py::reference_cast_error &err) {
93               // This exception seems thrown when the value is "None".
94               std::string msg =
95                   std::string("Invalid attribute (None?) when attempting to "
96                               "create an ArrayAttribute (") +
97                   err.what() + ")";
98               throw py::cast_error(msg);
99             }
100           }
101           MlirAttribute attr = mlirArrayAttrGet(
102               context->get(), mlirAttributes.size(), mlirAttributes.data());
103           return PyArrayAttribute(context->getRef(), attr);
104         },
105         py::arg("attributes"), py::arg("context") = py::none(),
106         "Gets a uniqued Array attribute");
107     c.def("__getitem__",
108           [](PyArrayAttribute &arr, intptr_t i) {
109             if (i >= mlirArrayAttrGetNumElements(arr))
110               throw py::index_error("ArrayAttribute index out of range");
111             return PyAttribute(arr.getContext(),
112                                mlirArrayAttrGetElement(arr, i));
113           })
114         .def("__len__",
115              [](const PyArrayAttribute &arr) {
116                return mlirArrayAttrGetNumElements(arr);
117              })
118         .def("__iter__", [](const PyArrayAttribute &arr) {
119           return PyArrayAttributeIterator(arr);
120         });
121   }
122 };
123 
124 /// Float Point Attribute subclass - FloatAttr.
125 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
126 public:
127   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
128   static constexpr const char *pyClassName = "FloatAttr";
129   using PyConcreteAttribute::PyConcreteAttribute;
130 
131   static void bindDerived(ClassTy &c) {
132     c.def_static(
133         "get",
134         [](PyType &type, double value, DefaultingPyLocation loc) {
135           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
136           // TODO: Rework error reporting once diagnostic engine is exposed
137           // in C API.
138           if (mlirAttributeIsNull(attr)) {
139             throw SetPyError(PyExc_ValueError,
140                              Twine("invalid '") +
141                                  py::repr(py::cast(type)).cast<std::string>() +
142                                  "' and expected floating point type.");
143           }
144           return PyFloatAttribute(type.getContext(), attr);
145         },
146         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
147         "Gets an uniqued float point attribute associated to a type");
148     c.def_static(
149         "get_f32",
150         [](double value, DefaultingPyMlirContext context) {
151           MlirAttribute attr = mlirFloatAttrDoubleGet(
152               context->get(), mlirF32TypeGet(context->get()), value);
153           return PyFloatAttribute(context->getRef(), attr);
154         },
155         py::arg("value"), py::arg("context") = py::none(),
156         "Gets an uniqued float point attribute associated to a f32 type");
157     c.def_static(
158         "get_f64",
159         [](double value, DefaultingPyMlirContext context) {
160           MlirAttribute attr = mlirFloatAttrDoubleGet(
161               context->get(), mlirF64TypeGet(context->get()), value);
162           return PyFloatAttribute(context->getRef(), attr);
163         },
164         py::arg("value"), py::arg("context") = py::none(),
165         "Gets an uniqued float point attribute associated to a f64 type");
166     c.def_property_readonly(
167         "value",
168         [](PyFloatAttribute &self) {
169           return mlirFloatAttrGetValueDouble(self);
170         },
171         "Returns the value of the float point attribute");
172   }
173 };
174 
175 /// Integer Attribute subclass - IntegerAttr.
176 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
177 public:
178   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
179   static constexpr const char *pyClassName = "IntegerAttr";
180   using PyConcreteAttribute::PyConcreteAttribute;
181 
182   static void bindDerived(ClassTy &c) {
183     c.def_static(
184         "get",
185         [](PyType &type, int64_t value) {
186           MlirAttribute attr = mlirIntegerAttrGet(type, value);
187           return PyIntegerAttribute(type.getContext(), attr);
188         },
189         py::arg("type"), py::arg("value"),
190         "Gets an uniqued integer attribute associated to a type");
191     c.def_property_readonly(
192         "value",
193         [](PyIntegerAttribute &self) {
194           return mlirIntegerAttrGetValueInt(self);
195         },
196         "Returns the value of the integer attribute");
197   }
198 };
199 
200 /// Bool Attribute subclass - BoolAttr.
201 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
202 public:
203   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
204   static constexpr const char *pyClassName = "BoolAttr";
205   using PyConcreteAttribute::PyConcreteAttribute;
206 
207   static void bindDerived(ClassTy &c) {
208     c.def_static(
209         "get",
210         [](bool value, DefaultingPyMlirContext context) {
211           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
212           return PyBoolAttribute(context->getRef(), attr);
213         },
214         py::arg("value"), py::arg("context") = py::none(),
215         "Gets an uniqued bool attribute");
216     c.def_property_readonly(
217         "value",
218         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
219         "Returns the value of the bool attribute");
220   }
221 };
222 
223 class PyFlatSymbolRefAttribute
224     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
225 public:
226   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
227   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
228   using PyConcreteAttribute::PyConcreteAttribute;
229 
230   static void bindDerived(ClassTy &c) {
231     c.def_static(
232         "get",
233         [](std::string value, DefaultingPyMlirContext context) {
234           MlirAttribute attr =
235               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
236           return PyFlatSymbolRefAttribute(context->getRef(), attr);
237         },
238         py::arg("value"), py::arg("context") = py::none(),
239         "Gets a uniqued FlatSymbolRef attribute");
240     c.def_property_readonly(
241         "value",
242         [](PyFlatSymbolRefAttribute &self) {
243           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
244           return py::str(stringRef.data, stringRef.length);
245         },
246         "Returns the value of the FlatSymbolRef attribute as a string");
247   }
248 };
249 
250 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
251 public:
252   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
253   static constexpr const char *pyClassName = "StringAttr";
254   using PyConcreteAttribute::PyConcreteAttribute;
255 
256   static void bindDerived(ClassTy &c) {
257     c.def_static(
258         "get",
259         [](std::string value, DefaultingPyMlirContext context) {
260           MlirAttribute attr =
261               mlirStringAttrGet(context->get(), toMlirStringRef(value));
262           return PyStringAttribute(context->getRef(), attr);
263         },
264         py::arg("value"), py::arg("context") = py::none(),
265         "Gets a uniqued string attribute");
266     c.def_static(
267         "get_typed",
268         [](PyType &type, std::string value) {
269           MlirAttribute attr =
270               mlirStringAttrTypedGet(type, toMlirStringRef(value));
271           return PyStringAttribute(type.getContext(), attr);
272         },
273 
274         "Gets a uniqued string attribute associated to a type");
275     c.def_property_readonly(
276         "value",
277         [](PyStringAttribute &self) {
278           MlirStringRef stringRef = mlirStringAttrGetValue(self);
279           return py::str(stringRef.data, stringRef.length);
280         },
281         "Returns the value of the string attribute");
282   }
283 };
284 
285 // TODO: Support construction of bool elements.
286 // TODO: Support construction of string elements.
287 class PyDenseElementsAttribute
288     : public PyConcreteAttribute<PyDenseElementsAttribute> {
289 public:
290   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
291   static constexpr const char *pyClassName = "DenseElementsAttr";
292   using PyConcreteAttribute::PyConcreteAttribute;
293 
294   static PyDenseElementsAttribute
295   getFromBuffer(py::buffer array, bool signless,
296                 DefaultingPyMlirContext contextWrapper) {
297     // Request a contiguous view. In exotic cases, this will cause a copy.
298     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
299     Py_buffer *view = new Py_buffer();
300     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
301       delete view;
302       throw py::error_already_set();
303     }
304     py::buffer_info arrayInfo(view);
305 
306     MlirContext context = contextWrapper->get();
307     // Switch on the types that can be bulk loaded between the Python and
308     // MLIR-C APIs.
309     // See: https://docs.python.org/3/library/struct.html#format-characters
310     if (arrayInfo.format == "f") {
311       // f32
312       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
313       return PyDenseElementsAttribute(
314           contextWrapper->getRef(),
315           bulkLoad(context, mlirDenseElementsAttrFloatGet,
316                    mlirF32TypeGet(context), arrayInfo));
317     } else if (arrayInfo.format == "d") {
318       // f64
319       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
320       return PyDenseElementsAttribute(
321           contextWrapper->getRef(),
322           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
323                    mlirF64TypeGet(context), arrayInfo));
324     } else if (isSignedIntegerFormat(arrayInfo.format)) {
325       if (arrayInfo.itemsize == 4) {
326         // i32
327         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
328                                         : mlirIntegerTypeSignedGet(context, 32);
329         return PyDenseElementsAttribute(contextWrapper->getRef(),
330                                         bulkLoad(context,
331                                                  mlirDenseElementsAttrInt32Get,
332                                                  elementType, arrayInfo));
333       } else if (arrayInfo.itemsize == 8) {
334         // i64
335         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
336                                         : mlirIntegerTypeSignedGet(context, 64);
337         return PyDenseElementsAttribute(contextWrapper->getRef(),
338                                         bulkLoad(context,
339                                                  mlirDenseElementsAttrInt64Get,
340                                                  elementType, arrayInfo));
341       }
342     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
343       if (arrayInfo.itemsize == 4) {
344         // unsigned i32
345         MlirType elementType = signless
346                                    ? mlirIntegerTypeGet(context, 32)
347                                    : mlirIntegerTypeUnsignedGet(context, 32);
348         return PyDenseElementsAttribute(contextWrapper->getRef(),
349                                         bulkLoad(context,
350                                                  mlirDenseElementsAttrUInt32Get,
351                                                  elementType, arrayInfo));
352       } else if (arrayInfo.itemsize == 8) {
353         // unsigned i64
354         MlirType elementType = signless
355                                    ? mlirIntegerTypeGet(context, 64)
356                                    : mlirIntegerTypeUnsignedGet(context, 64);
357         return PyDenseElementsAttribute(contextWrapper->getRef(),
358                                         bulkLoad(context,
359                                                  mlirDenseElementsAttrUInt64Get,
360                                                  elementType, arrayInfo));
361       }
362     }
363 
364     // TODO: Fall back to string-based get.
365     std::string message = "unimplemented array format conversion from format: ";
366     message.append(arrayInfo.format);
367     throw SetPyError(PyExc_ValueError, message);
368   }
369 
370   static PyDenseElementsAttribute getSplat(PyType shapedType,
371                                            PyAttribute &elementAttr) {
372     auto contextWrapper =
373         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
374     if (!mlirAttributeIsAInteger(elementAttr) &&
375         !mlirAttributeIsAFloat(elementAttr)) {
376       std::string message = "Illegal element type for DenseElementsAttr: ";
377       message.append(py::repr(py::cast(elementAttr)));
378       throw SetPyError(PyExc_ValueError, message);
379     }
380     if (!mlirTypeIsAShaped(shapedType) ||
381         !mlirShapedTypeHasStaticShape(shapedType)) {
382       std::string message =
383           "Expected a static ShapedType for the shaped_type parameter: ";
384       message.append(py::repr(py::cast(shapedType)));
385       throw SetPyError(PyExc_ValueError, message);
386     }
387     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
388     MlirType attrType = mlirAttributeGetType(elementAttr);
389     if (!mlirTypeEqual(shapedElementType, attrType)) {
390       std::string message =
391           "Shaped element type and attribute type must be equal: shaped=";
392       message.append(py::repr(py::cast(shapedType)));
393       message.append(", element=");
394       message.append(py::repr(py::cast(elementAttr)));
395       throw SetPyError(PyExc_ValueError, message);
396     }
397 
398     MlirAttribute elements =
399         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
400     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
401   }
402 
403   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
404 
405   py::buffer_info accessBuffer() {
406     MlirType shapedType = mlirAttributeGetType(*this);
407     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
408 
409     if (mlirTypeIsAF32(elementType)) {
410       // f32
411       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
412     } else if (mlirTypeIsAF64(elementType)) {
413       // f64
414       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
415     } else if (mlirTypeIsAInteger(elementType) &&
416                mlirIntegerTypeGetWidth(elementType) == 32) {
417       if (mlirIntegerTypeIsSignless(elementType) ||
418           mlirIntegerTypeIsSigned(elementType)) {
419         // i32
420         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
421       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
422         // unsigned i32
423         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
424       }
425     } else if (mlirTypeIsAInteger(elementType) &&
426                mlirIntegerTypeGetWidth(elementType) == 64) {
427       if (mlirIntegerTypeIsSignless(elementType) ||
428           mlirIntegerTypeIsSigned(elementType)) {
429         // i64
430         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
431       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
432         // unsigned i64
433         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
434       }
435     }
436 
437     std::string message = "unimplemented array format.";
438     throw SetPyError(PyExc_ValueError, message);
439   }
440 
441   static void bindDerived(ClassTy &c) {
442     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
443         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
444                     py::arg("array"), py::arg("signless") = true,
445                     py::arg("context") = py::none(),
446                     "Gets from a buffer or ndarray")
447         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
448                     py::arg("shaped_type"), py::arg("element_attr"),
449                     "Gets a DenseElementsAttr where all values are the same")
450         .def_property_readonly("is_splat",
451                                [](PyDenseElementsAttribute &self) -> bool {
452                                  return mlirDenseElementsAttrIsSplat(self);
453                                })
454         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
455   }
456 
457 private:
458   template <typename ElementTy>
459   static MlirAttribute
460   bulkLoad(MlirContext context,
461            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
462            MlirType mlirElementType, py::buffer_info &arrayInfo) {
463     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
464                                   arrayInfo.shape.begin() + arrayInfo.ndim);
465     MlirAttribute encodingAttr = mlirAttributeGetNull();
466     auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
467                                               mlirElementType, encodingAttr);
468     intptr_t numElements = arrayInfo.size;
469     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
470     return ctor(shapedType, numElements, contents);
471   }
472 
473   static bool isUnsignedIntegerFormat(const std::string &format) {
474     if (format.empty())
475       return false;
476     char code = format[0];
477     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
478            code == 'Q';
479   }
480 
481   static bool isSignedIntegerFormat(const std::string &format) {
482     if (format.empty())
483       return false;
484     char code = format[0];
485     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
486            code == 'q';
487   }
488 
489   template <typename Type>
490   py::buffer_info bufferInfo(MlirType shapedType,
491                              Type (*value)(MlirAttribute, intptr_t)) {
492     intptr_t rank = mlirShapedTypeGetRank(shapedType);
493     // Prepare the data for the buffer_info.
494     // Buffer is configured for read-only access below.
495     Type *data = static_cast<Type *>(
496         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
497     // Prepare the shape for the buffer_info.
498     SmallVector<intptr_t, 4> shape;
499     for (intptr_t i = 0; i < rank; ++i)
500       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
501     // Prepare the strides for the buffer_info.
502     SmallVector<intptr_t, 4> strides;
503     intptr_t strideFactor = 1;
504     for (intptr_t i = 1; i < rank; ++i) {
505       strideFactor = 1;
506       for (intptr_t j = i; j < rank; ++j) {
507         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
508       }
509       strides.push_back(sizeof(Type) * strideFactor);
510     }
511     strides.push_back(sizeof(Type));
512     return py::buffer_info(data, sizeof(Type),
513                            py::format_descriptor<Type>::format(), rank, shape,
514                            strides, /*readonly=*/true);
515   }
516 }; // namespace
517 
518 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
519 /// (and boolean) values. Supports element access.
520 class PyDenseIntElementsAttribute
521     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
522                                  PyDenseElementsAttribute> {
523 public:
524   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
525   static constexpr const char *pyClassName = "DenseIntElementsAttr";
526   using PyConcreteAttribute::PyConcreteAttribute;
527 
528   /// Returns the element at the given linear position. Asserts if the index is
529   /// out of range.
530   py::int_ dunderGetItem(intptr_t pos) {
531     if (pos < 0 || pos >= dunderLen()) {
532       throw SetPyError(PyExc_IndexError,
533                        "attempt to access out of bounds element");
534     }
535 
536     MlirType type = mlirAttributeGetType(*this);
537     type = mlirShapedTypeGetElementType(type);
538     assert(mlirTypeIsAInteger(type) &&
539            "expected integer element type in dense int elements attribute");
540     // Dispatch element extraction to an appropriate C function based on the
541     // elemental type of the attribute. py::int_ is implicitly constructible
542     // from any C++ integral type and handles bitwidth correctly.
543     // TODO: consider caching the type properties in the constructor to avoid
544     // querying them on each element access.
545     unsigned width = mlirIntegerTypeGetWidth(type);
546     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
547     if (isUnsigned) {
548       if (width == 1) {
549         return mlirDenseElementsAttrGetBoolValue(*this, pos);
550       }
551       if (width == 32) {
552         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
553       }
554       if (width == 64) {
555         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
556       }
557     } else {
558       if (width == 1) {
559         return mlirDenseElementsAttrGetBoolValue(*this, pos);
560       }
561       if (width == 32) {
562         return mlirDenseElementsAttrGetInt32Value(*this, pos);
563       }
564       if (width == 64) {
565         return mlirDenseElementsAttrGetInt64Value(*this, pos);
566       }
567     }
568     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
569   }
570 
571   static void bindDerived(ClassTy &c) {
572     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
573   }
574 };
575 
576 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
577 public:
578   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
579   static constexpr const char *pyClassName = "DictAttr";
580   using PyConcreteAttribute::PyConcreteAttribute;
581 
582   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
583 
584   static void bindDerived(ClassTy &c) {
585     c.def("__len__", &PyDictAttribute::dunderLen);
586     c.def_static(
587         "get",
588         [](py::dict attributes, DefaultingPyMlirContext context) {
589           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
590           mlirNamedAttributes.reserve(attributes.size());
591           for (auto &it : attributes) {
592             auto &mlir_attr = it.second.cast<PyAttribute &>();
593             auto name = it.first.cast<std::string>();
594             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
595                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
596                                   toMlirStringRef(name)),
597                 mlir_attr));
598           }
599           MlirAttribute attr =
600               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
601                                     mlirNamedAttributes.data());
602           return PyDictAttribute(context->getRef(), attr);
603         },
604         py::arg("value"), py::arg("context") = py::none(),
605         "Gets an uniqued dict attribute");
606     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
607       MlirAttribute attr =
608           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
609       if (mlirAttributeIsNull(attr)) {
610         throw SetPyError(PyExc_KeyError,
611                          "attempt to access a non-existent attribute");
612       }
613       return PyAttribute(self.getContext(), attr);
614     });
615     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
616       if (index < 0 || index >= self.dunderLen()) {
617         throw SetPyError(PyExc_IndexError,
618                          "attempt to access out of bounds attribute");
619       }
620       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
621       return PyNamedAttribute(
622           namedAttr.attribute,
623           std::string(mlirIdentifierStr(namedAttr.name).data));
624     });
625   }
626 };
627 
628 /// Refinement of PyDenseElementsAttribute for attributes containing
629 /// floating-point values. Supports element access.
630 class PyDenseFPElementsAttribute
631     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
632                                  PyDenseElementsAttribute> {
633 public:
634   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
635   static constexpr const char *pyClassName = "DenseFPElementsAttr";
636   using PyConcreteAttribute::PyConcreteAttribute;
637 
638   py::float_ dunderGetItem(intptr_t pos) {
639     if (pos < 0 || pos >= dunderLen()) {
640       throw SetPyError(PyExc_IndexError,
641                        "attempt to access out of bounds element");
642     }
643 
644     MlirType type = mlirAttributeGetType(*this);
645     type = mlirShapedTypeGetElementType(type);
646     // Dispatch element extraction to an appropriate C function based on the
647     // elemental type of the attribute. py::float_ is implicitly constructible
648     // from float and double.
649     // TODO: consider caching the type properties in the constructor to avoid
650     // querying them on each element access.
651     if (mlirTypeIsAF32(type)) {
652       return mlirDenseElementsAttrGetFloatValue(*this, pos);
653     }
654     if (mlirTypeIsAF64(type)) {
655       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
656     }
657     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
658   }
659 
660   static void bindDerived(ClassTy &c) {
661     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
662   }
663 };
664 
665 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
666 public:
667   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
668   static constexpr const char *pyClassName = "TypeAttr";
669   using PyConcreteAttribute::PyConcreteAttribute;
670 
671   static void bindDerived(ClassTy &c) {
672     c.def_static(
673         "get",
674         [](PyType value, DefaultingPyMlirContext context) {
675           MlirAttribute attr = mlirTypeAttrGet(value.get());
676           return PyTypeAttribute(context->getRef(), attr);
677         },
678         py::arg("value"), py::arg("context") = py::none(),
679         "Gets a uniqued Type attribute");
680     c.def_property_readonly("value", [](PyTypeAttribute &self) {
681       return PyType(self.getContext()->getRef(),
682                     mlirTypeAttrGetValue(self.get()));
683     });
684   }
685 };
686 
687 /// Unit Attribute subclass. Unit attributes don't have values.
688 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
689 public:
690   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
691   static constexpr const char *pyClassName = "UnitAttr";
692   using PyConcreteAttribute::PyConcreteAttribute;
693 
694   static void bindDerived(ClassTy &c) {
695     c.def_static(
696         "get",
697         [](DefaultingPyMlirContext context) {
698           return PyUnitAttribute(context->getRef(),
699                                  mlirUnitAttrGet(context->get()));
700         },
701         py::arg("context") = py::none(), "Create a Unit attribute.");
702   }
703 };
704 
705 } // namespace
706 
/// Registers every attribute binding defined in this file, plus the ArrayAttr
/// iterator helper class, on the Python module `m`.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // NOTE: PyDenseFPElementsAttribute and PyDenseIntElementsAttribute derive
  // from PyDenseElementsAttribute, which is bound before them here. Keep that
  // order (presumably required so pybind11 knows the base class first).
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
723