1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::StringRef;
22 using llvm::Twine;
23 
24 namespace {
25 
26 static MlirStringRef toMlirStringRef(const std::string &s) {
27   return mlirStringRefCreate(s.data(), s.size());
28 }
29 
30 /// CRTP base classes for Python attributes that subclass Attribute and should
31 /// be castable from it (i.e. via something like StringAttr(attr)).
32 /// By default, attribute class hierarchies are one level deep (i.e. a
33 /// concrete attribute class extends PyAttribute); however, intermediate
34 /// python-visible base classes can be modeled by specifying a BaseTy.
35 template <typename DerivedTy, typename BaseTy = PyAttribute>
36 class PyConcreteAttribute : public BaseTy {
37 public:
38   // Derived classes must define statics for:
39   //   IsAFunctionTy isaFunction
40   //   const char *pyClassName
41   using ClassTy = py::class_<DerivedTy, BaseTy>;
42   using IsAFunctionTy = bool (*)(MlirAttribute);
43 
44   PyConcreteAttribute() = default;
45   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
46       : BaseTy(std::move(contextRef), attr) {}
47   PyConcreteAttribute(PyAttribute &orig)
48       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
49 
50   static MlirAttribute castFrom(PyAttribute &orig) {
51     if (!DerivedTy::isaFunction(orig)) {
52       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
53       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
54                                              DerivedTy::pyClassName +
55                                              " (from " + origRepr + ")");
56     }
57     return orig;
58   }
59 
60   static void bind(py::module &m) {
61     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
62     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
63     DerivedTy::bindDerived(cls);
64   }
65 
66   /// Implemented by derived classes to add methods to the Python subclass.
67   static void bindDerived(ClassTy &m) {}
68 };
69 
70 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
71 public:
72   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
73   static constexpr const char *pyClassName = "AffineMapAttr";
74   using PyConcreteAttribute::PyConcreteAttribute;
75 
76   static void bindDerived(ClassTy &c) {
77     c.def_static(
78         "get",
79         [](PyAffineMap &affineMap) {
80           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
81           return PyAffineMapAttribute(affineMap.getContext(), attr);
82         },
83         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
84   }
85 };
86 
87 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
88 public:
89   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
90   static constexpr const char *pyClassName = "ArrayAttr";
91   using PyConcreteAttribute::PyConcreteAttribute;
92 
93   class PyArrayAttributeIterator {
94   public:
95     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
96 
97     PyArrayAttributeIterator &dunderIter() { return *this; }
98 
99     PyAttribute dunderNext() {
100       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
101         throw py::stop_iteration();
102       }
103       return PyAttribute(attr.getContext(),
104                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
105     }
106 
107     static void bind(py::module &m) {
108       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
109           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
110           .def("__next__", &PyArrayAttributeIterator::dunderNext);
111     }
112 
113   private:
114     PyAttribute attr;
115     int nextIndex = 0;
116   };
117 
118   static void bindDerived(ClassTy &c) {
119     c.def_static(
120         "get",
121         [](py::list attributes, DefaultingPyMlirContext context) {
122           SmallVector<MlirAttribute> mlirAttributes;
123           mlirAttributes.reserve(py::len(attributes));
124           for (auto attribute : attributes) {
125             try {
126               mlirAttributes.push_back(attribute.cast<PyAttribute>());
127             } catch (py::cast_error &err) {
128               std::string msg = std::string("Invalid attribute when attempting "
129                                             "to create an ArrayAttribute (") +
130                                 err.what() + ")";
131               throw py::cast_error(msg);
132             } catch (py::reference_cast_error &err) {
133               // This exception seems thrown when the value is "None".
134               std::string msg =
135                   std::string("Invalid attribute (None?) when attempting to "
136                               "create an ArrayAttribute (") +
137                   err.what() + ")";
138               throw py::cast_error(msg);
139             }
140           }
141           MlirAttribute attr = mlirArrayAttrGet(
142               context->get(), mlirAttributes.size(), mlirAttributes.data());
143           return PyArrayAttribute(context->getRef(), attr);
144         },
145         py::arg("attributes"), py::arg("context") = py::none(),
146         "Gets a uniqued Array attribute");
147     c.def("__getitem__",
148           [](PyArrayAttribute &arr, intptr_t i) {
149             if (i >= mlirArrayAttrGetNumElements(arr))
150               throw py::index_error("ArrayAttribute index out of range");
151             return PyAttribute(arr.getContext(),
152                                mlirArrayAttrGetElement(arr, i));
153           })
154         .def("__len__",
155              [](const PyArrayAttribute &arr) {
156                return mlirArrayAttrGetNumElements(arr);
157              })
158         .def("__iter__", [](const PyArrayAttribute &arr) {
159           return PyArrayAttributeIterator(arr);
160         });
161   }
162 };
163 
164 /// Float Point Attribute subclass - FloatAttr.
165 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
166 public:
167   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
168   static constexpr const char *pyClassName = "FloatAttr";
169   using PyConcreteAttribute::PyConcreteAttribute;
170 
171   static void bindDerived(ClassTy &c) {
172     c.def_static(
173         "get",
174         [](PyType &type, double value, DefaultingPyLocation loc) {
175           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
176           // TODO: Rework error reporting once diagnostic engine is exposed
177           // in C API.
178           if (mlirAttributeIsNull(attr)) {
179             throw SetPyError(PyExc_ValueError,
180                              Twine("invalid '") +
181                                  py::repr(py::cast(type)).cast<std::string>() +
182                                  "' and expected floating point type.");
183           }
184           return PyFloatAttribute(type.getContext(), attr);
185         },
186         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
187         "Gets an uniqued float point attribute associated to a type");
188     c.def_static(
189         "get_f32",
190         [](double value, DefaultingPyMlirContext context) {
191           MlirAttribute attr = mlirFloatAttrDoubleGet(
192               context->get(), mlirF32TypeGet(context->get()), value);
193           return PyFloatAttribute(context->getRef(), attr);
194         },
195         py::arg("value"), py::arg("context") = py::none(),
196         "Gets an uniqued float point attribute associated to a f32 type");
197     c.def_static(
198         "get_f64",
199         [](double value, DefaultingPyMlirContext context) {
200           MlirAttribute attr = mlirFloatAttrDoubleGet(
201               context->get(), mlirF64TypeGet(context->get()), value);
202           return PyFloatAttribute(context->getRef(), attr);
203         },
204         py::arg("value"), py::arg("context") = py::none(),
205         "Gets an uniqued float point attribute associated to a f64 type");
206     c.def_property_readonly(
207         "value",
208         [](PyFloatAttribute &self) {
209           return mlirFloatAttrGetValueDouble(self);
210         },
211         "Returns the value of the float point attribute");
212   }
213 };
214 
215 /// Integer Attribute subclass - IntegerAttr.
216 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
217 public:
218   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
219   static constexpr const char *pyClassName = "IntegerAttr";
220   using PyConcreteAttribute::PyConcreteAttribute;
221 
222   static void bindDerived(ClassTy &c) {
223     c.def_static(
224         "get",
225         [](PyType &type, int64_t value) {
226           MlirAttribute attr = mlirIntegerAttrGet(type, value);
227           return PyIntegerAttribute(type.getContext(), attr);
228         },
229         py::arg("type"), py::arg("value"),
230         "Gets an uniqued integer attribute associated to a type");
231     c.def_property_readonly(
232         "value",
233         [](PyIntegerAttribute &self) {
234           return mlirIntegerAttrGetValueInt(self);
235         },
236         "Returns the value of the integer attribute");
237   }
238 };
239 
240 /// Bool Attribute subclass - BoolAttr.
241 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
242 public:
243   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
244   static constexpr const char *pyClassName = "BoolAttr";
245   using PyConcreteAttribute::PyConcreteAttribute;
246 
247   static void bindDerived(ClassTy &c) {
248     c.def_static(
249         "get",
250         [](bool value, DefaultingPyMlirContext context) {
251           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
252           return PyBoolAttribute(context->getRef(), attr);
253         },
254         py::arg("value"), py::arg("context") = py::none(),
255         "Gets an uniqued bool attribute");
256     c.def_property_readonly(
257         "value",
258         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
259         "Returns the value of the bool attribute");
260   }
261 };
262 
263 class PyFlatSymbolRefAttribute
264     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
265 public:
266   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
267   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
268   using PyConcreteAttribute::PyConcreteAttribute;
269 
270   static void bindDerived(ClassTy &c) {
271     c.def_static(
272         "get",
273         [](std::string value, DefaultingPyMlirContext context) {
274           MlirAttribute attr =
275               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
276           return PyFlatSymbolRefAttribute(context->getRef(), attr);
277         },
278         py::arg("value"), py::arg("context") = py::none(),
279         "Gets a uniqued FlatSymbolRef attribute");
280     c.def_property_readonly(
281         "value",
282         [](PyFlatSymbolRefAttribute &self) {
283           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
284           return py::str(stringRef.data, stringRef.length);
285         },
286         "Returns the value of the FlatSymbolRef attribute as a string");
287   }
288 };
289 
290 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
291 public:
292   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
293   static constexpr const char *pyClassName = "StringAttr";
294   using PyConcreteAttribute::PyConcreteAttribute;
295 
296   static void bindDerived(ClassTy &c) {
297     c.def_static(
298         "get",
299         [](std::string value, DefaultingPyMlirContext context) {
300           MlirAttribute attr =
301               mlirStringAttrGet(context->get(), toMlirStringRef(value));
302           return PyStringAttribute(context->getRef(), attr);
303         },
304         py::arg("value"), py::arg("context") = py::none(),
305         "Gets a uniqued string attribute");
306     c.def_static(
307         "get_typed",
308         [](PyType &type, std::string value) {
309           MlirAttribute attr =
310               mlirStringAttrTypedGet(type, toMlirStringRef(value));
311           return PyStringAttribute(type.getContext(), attr);
312         },
313 
314         "Gets a uniqued string attribute associated to a type");
315     c.def_property_readonly(
316         "value",
317         [](PyStringAttribute &self) {
318           MlirStringRef stringRef = mlirStringAttrGetValue(self);
319           return py::str(stringRef.data, stringRef.length);
320         },
321         "Returns the value of the string attribute");
322   }
323 };
324 
325 // TODO: Support construction of bool elements.
326 // TODO: Support construction of string elements.
327 class PyDenseElementsAttribute
328     : public PyConcreteAttribute<PyDenseElementsAttribute> {
329 public:
330   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
331   static constexpr const char *pyClassName = "DenseElementsAttr";
332   using PyConcreteAttribute::PyConcreteAttribute;
333 
334   static PyDenseElementsAttribute
335   getFromBuffer(py::buffer array, bool signless,
336                 DefaultingPyMlirContext contextWrapper) {
337     // Request a contiguous view. In exotic cases, this will cause a copy.
338     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
339     Py_buffer *view = new Py_buffer();
340     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
341       delete view;
342       throw py::error_already_set();
343     }
344     py::buffer_info arrayInfo(view);
345 
346     MlirContext context = contextWrapper->get();
347     // Switch on the types that can be bulk loaded between the Python and
348     // MLIR-C APIs.
349     // See: https://docs.python.org/3/library/struct.html#format-characters
350     if (arrayInfo.format == "f") {
351       // f32
352       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
353       return PyDenseElementsAttribute(
354           contextWrapper->getRef(),
355           bulkLoad(context, mlirDenseElementsAttrFloatGet,
356                    mlirF32TypeGet(context), arrayInfo));
357     } else if (arrayInfo.format == "d") {
358       // f64
359       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
360       return PyDenseElementsAttribute(
361           contextWrapper->getRef(),
362           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
363                    mlirF64TypeGet(context), arrayInfo));
364     } else if (isSignedIntegerFormat(arrayInfo.format)) {
365       if (arrayInfo.itemsize == 4) {
366         // i32
367         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
368                                         : mlirIntegerTypeSignedGet(context, 32);
369         return PyDenseElementsAttribute(contextWrapper->getRef(),
370                                         bulkLoad(context,
371                                                  mlirDenseElementsAttrInt32Get,
372                                                  elementType, arrayInfo));
373       } else if (arrayInfo.itemsize == 8) {
374         // i64
375         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
376                                         : mlirIntegerTypeSignedGet(context, 64);
377         return PyDenseElementsAttribute(contextWrapper->getRef(),
378                                         bulkLoad(context,
379                                                  mlirDenseElementsAttrInt64Get,
380                                                  elementType, arrayInfo));
381       }
382     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
383       if (arrayInfo.itemsize == 4) {
384         // unsigned i32
385         MlirType elementType = signless
386                                    ? mlirIntegerTypeGet(context, 32)
387                                    : mlirIntegerTypeUnsignedGet(context, 32);
388         return PyDenseElementsAttribute(contextWrapper->getRef(),
389                                         bulkLoad(context,
390                                                  mlirDenseElementsAttrUInt32Get,
391                                                  elementType, arrayInfo));
392       } else if (arrayInfo.itemsize == 8) {
393         // unsigned i64
394         MlirType elementType = signless
395                                    ? mlirIntegerTypeGet(context, 64)
396                                    : mlirIntegerTypeUnsignedGet(context, 64);
397         return PyDenseElementsAttribute(contextWrapper->getRef(),
398                                         bulkLoad(context,
399                                                  mlirDenseElementsAttrUInt64Get,
400                                                  elementType, arrayInfo));
401       }
402     }
403 
404     // TODO: Fall back to string-based get.
405     std::string message = "unimplemented array format conversion from format: ";
406     message.append(arrayInfo.format);
407     throw SetPyError(PyExc_ValueError, message);
408   }
409 
410   static PyDenseElementsAttribute getSplat(PyType shapedType,
411                                            PyAttribute &elementAttr) {
412     auto contextWrapper =
413         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
414     if (!mlirAttributeIsAInteger(elementAttr) &&
415         !mlirAttributeIsAFloat(elementAttr)) {
416       std::string message = "Illegal element type for DenseElementsAttr: ";
417       message.append(py::repr(py::cast(elementAttr)));
418       throw SetPyError(PyExc_ValueError, message);
419     }
420     if (!mlirTypeIsAShaped(shapedType) ||
421         !mlirShapedTypeHasStaticShape(shapedType)) {
422       std::string message =
423           "Expected a static ShapedType for the shaped_type parameter: ";
424       message.append(py::repr(py::cast(shapedType)));
425       throw SetPyError(PyExc_ValueError, message);
426     }
427     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
428     MlirType attrType = mlirAttributeGetType(elementAttr);
429     if (!mlirTypeEqual(shapedElementType, attrType)) {
430       std::string message =
431           "Shaped element type and attribute type must be equal: shaped=";
432       message.append(py::repr(py::cast(shapedType)));
433       message.append(", element=");
434       message.append(py::repr(py::cast(elementAttr)));
435       throw SetPyError(PyExc_ValueError, message);
436     }
437 
438     MlirAttribute elements =
439         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
440     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
441   }
442 
443   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
444 
445   py::buffer_info accessBuffer() {
446     MlirType shapedType = mlirAttributeGetType(*this);
447     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
448 
449     if (mlirTypeIsAF32(elementType)) {
450       // f32
451       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
452     } else if (mlirTypeIsAF64(elementType)) {
453       // f64
454       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
455     } else if (mlirTypeIsAInteger(elementType) &&
456                mlirIntegerTypeGetWidth(elementType) == 32) {
457       if (mlirIntegerTypeIsSignless(elementType) ||
458           mlirIntegerTypeIsSigned(elementType)) {
459         // i32
460         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
461       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
462         // unsigned i32
463         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
464       }
465     } else if (mlirTypeIsAInteger(elementType) &&
466                mlirIntegerTypeGetWidth(elementType) == 64) {
467       if (mlirIntegerTypeIsSignless(elementType) ||
468           mlirIntegerTypeIsSigned(elementType)) {
469         // i64
470         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
471       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
472         // unsigned i64
473         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
474       }
475     }
476 
477     std::string message = "unimplemented array format.";
478     throw SetPyError(PyExc_ValueError, message);
479   }
480 
481   static void bindDerived(ClassTy &c) {
482     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
483         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
484                     py::arg("array"), py::arg("signless") = true,
485                     py::arg("context") = py::none(),
486                     "Gets from a buffer or ndarray")
487         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
488                     py::arg("shaped_type"), py::arg("element_attr"),
489                     "Gets a DenseElementsAttr where all values are the same")
490         .def_property_readonly("is_splat",
491                                [](PyDenseElementsAttribute &self) -> bool {
492                                  return mlirDenseElementsAttrIsSplat(self);
493                                })
494         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
495   }
496 
497 private:
498   template <typename ElementTy>
499   static MlirAttribute
500   bulkLoad(MlirContext context,
501            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
502            MlirType mlirElementType, py::buffer_info &arrayInfo) {
503     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
504                                   arrayInfo.shape.begin() + arrayInfo.ndim);
505     auto shapedType =
506         mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
507     intptr_t numElements = arrayInfo.size;
508     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
509     return ctor(shapedType, numElements, contents);
510   }
511 
512   static bool isUnsignedIntegerFormat(const std::string &format) {
513     if (format.empty())
514       return false;
515     char code = format[0];
516     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
517            code == 'Q';
518   }
519 
520   static bool isSignedIntegerFormat(const std::string &format) {
521     if (format.empty())
522       return false;
523     char code = format[0];
524     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
525            code == 'q';
526   }
527 
528   template <typename Type>
529   py::buffer_info bufferInfo(MlirType shapedType,
530                              Type (*value)(MlirAttribute, intptr_t)) {
531     intptr_t rank = mlirShapedTypeGetRank(shapedType);
532     // Prepare the data for the buffer_info.
533     // Buffer is configured for read-only access below.
534     Type *data = static_cast<Type *>(
535         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
536     // Prepare the shape for the buffer_info.
537     SmallVector<intptr_t, 4> shape;
538     for (intptr_t i = 0; i < rank; ++i)
539       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
540     // Prepare the strides for the buffer_info.
541     SmallVector<intptr_t, 4> strides;
542     intptr_t strideFactor = 1;
543     for (intptr_t i = 1; i < rank; ++i) {
544       strideFactor = 1;
545       for (intptr_t j = i; j < rank; ++j) {
546         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
547       }
548       strides.push_back(sizeof(Type) * strideFactor);
549     }
550     strides.push_back(sizeof(Type));
551     return py::buffer_info(data, sizeof(Type),
552                            py::format_descriptor<Type>::format(), rank, shape,
553                            strides, /*readonly=*/true);
554   }
555 }; // namespace
556 
557 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
558 /// (and boolean) values. Supports element access.
559 class PyDenseIntElementsAttribute
560     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
561                                  PyDenseElementsAttribute> {
562 public:
563   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
564   static constexpr const char *pyClassName = "DenseIntElementsAttr";
565   using PyConcreteAttribute::PyConcreteAttribute;
566 
567   /// Returns the element at the given linear position. Asserts if the index is
568   /// out of range.
569   py::int_ dunderGetItem(intptr_t pos) {
570     if (pos < 0 || pos >= dunderLen()) {
571       throw SetPyError(PyExc_IndexError,
572                        "attempt to access out of bounds element");
573     }
574 
575     MlirType type = mlirAttributeGetType(*this);
576     type = mlirShapedTypeGetElementType(type);
577     assert(mlirTypeIsAInteger(type) &&
578            "expected integer element type in dense int elements attribute");
579     // Dispatch element extraction to an appropriate C function based on the
580     // elemental type of the attribute. py::int_ is implicitly constructible
581     // from any C++ integral type and handles bitwidth correctly.
582     // TODO: consider caching the type properties in the constructor to avoid
583     // querying them on each element access.
584     unsigned width = mlirIntegerTypeGetWidth(type);
585     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
586     if (isUnsigned) {
587       if (width == 1) {
588         return mlirDenseElementsAttrGetBoolValue(*this, pos);
589       }
590       if (width == 32) {
591         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
592       }
593       if (width == 64) {
594         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
595       }
596     } else {
597       if (width == 1) {
598         return mlirDenseElementsAttrGetBoolValue(*this, pos);
599       }
600       if (width == 32) {
601         return mlirDenseElementsAttrGetInt32Value(*this, pos);
602       }
603       if (width == 64) {
604         return mlirDenseElementsAttrGetInt64Value(*this, pos);
605       }
606     }
607     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
608   }
609 
610   static void bindDerived(ClassTy &c) {
611     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
612   }
613 };
614 
615 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
616 public:
617   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
618   static constexpr const char *pyClassName = "DictAttr";
619   using PyConcreteAttribute::PyConcreteAttribute;
620 
621   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
622 
623   static void bindDerived(ClassTy &c) {
624     c.def("__len__", &PyDictAttribute::dunderLen);
625     c.def_static(
626         "get",
627         [](py::dict attributes, DefaultingPyMlirContext context) {
628           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
629           mlirNamedAttributes.reserve(attributes.size());
630           for (auto &it : attributes) {
631             auto &mlir_attr = it.second.cast<PyAttribute &>();
632             auto name = it.first.cast<std::string>();
633             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
634                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
635                                   toMlirStringRef(name)),
636                 mlir_attr));
637           }
638           MlirAttribute attr =
639               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
640                                     mlirNamedAttributes.data());
641           return PyDictAttribute(context->getRef(), attr);
642         },
643         py::arg("value"), py::arg("context") = py::none(),
644         "Gets an uniqued dict attribute");
645     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
646       MlirAttribute attr =
647           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
648       if (mlirAttributeIsNull(attr)) {
649         throw SetPyError(PyExc_KeyError,
650                          "attempt to access a non-existent attribute");
651       }
652       return PyAttribute(self.getContext(), attr);
653     });
654     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
655       if (index < 0 || index >= self.dunderLen()) {
656         throw SetPyError(PyExc_IndexError,
657                          "attempt to access out of bounds attribute");
658       }
659       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
660       return PyNamedAttribute(
661           namedAttr.attribute,
662           std::string(mlirIdentifierStr(namedAttr.name).data));
663     });
664   }
665 };
666 
667 /// Refinement of PyDenseElementsAttribute for attributes containing
668 /// floating-point values. Supports element access.
669 class PyDenseFPElementsAttribute
670     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
671                                  PyDenseElementsAttribute> {
672 public:
673   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
674   static constexpr const char *pyClassName = "DenseFPElementsAttr";
675   using PyConcreteAttribute::PyConcreteAttribute;
676 
677   py::float_ dunderGetItem(intptr_t pos) {
678     if (pos < 0 || pos >= dunderLen()) {
679       throw SetPyError(PyExc_IndexError,
680                        "attempt to access out of bounds element");
681     }
682 
683     MlirType type = mlirAttributeGetType(*this);
684     type = mlirShapedTypeGetElementType(type);
685     // Dispatch element extraction to an appropriate C function based on the
686     // elemental type of the attribute. py::float_ is implicitly constructible
687     // from float and double.
688     // TODO: consider caching the type properties in the constructor to avoid
689     // querying them on each element access.
690     if (mlirTypeIsAF32(type)) {
691       return mlirDenseElementsAttrGetFloatValue(*this, pos);
692     }
693     if (mlirTypeIsAF64(type)) {
694       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
695     }
696     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
697   }
698 
699   static void bindDerived(ClassTy &c) {
700     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
701   }
702 };
703 
704 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
705 public:
706   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
707   static constexpr const char *pyClassName = "TypeAttr";
708   using PyConcreteAttribute::PyConcreteAttribute;
709 
710   static void bindDerived(ClassTy &c) {
711     c.def_static(
712         "get",
713         [](PyType value, DefaultingPyMlirContext context) {
714           MlirAttribute attr = mlirTypeAttrGet(value.get());
715           return PyTypeAttribute(context->getRef(), attr);
716         },
717         py::arg("value"), py::arg("context") = py::none(),
718         "Gets a uniqued Type attribute");
719     c.def_property_readonly("value", [](PyTypeAttribute &self) {
720       return PyType(self.getContext()->getRef(),
721                     mlirTypeAttrGetValue(self.get()));
722     });
723   }
724 };
725 
726 /// Unit Attribute subclass. Unit attributes don't have values.
727 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
728 public:
729   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
730   static constexpr const char *pyClassName = "UnitAttr";
731   using PyConcreteAttribute::PyConcreteAttribute;
732 
733   static void bindDerived(ClassTy &c) {
734     c.def_static(
735         "get",
736         [](DefaultingPyMlirContext context) {
737           return PyUnitAttribute(context->getRef(),
738                                  mlirUnitAttrGet(context->get()));
739         },
740         py::arg("context") = py::none(), "Create a Unit attribute.");
741   }
742 };
743 
744 } // namespace
745 
/// Registers every attribute class defined in this file on module `m`.
/// Order matters for subclasses: PyDenseElementsAttribute is the pybind11
/// base of PyDenseFPElementsAttribute and PyDenseIntElementsAttribute, and
/// pybind11 requires a base class to be registered before classes that
/// derive from it.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  // Iterator type returned by ArrayAttr.__iter__.
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base class first, then its DenseFP/DenseInt refinements.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
762