1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::Twine;
22 
23 namespace {
24 
25 static MlirStringRef toMlirStringRef(const std::string &s) {
26   return mlirStringRefCreate(s.data(), s.size());
27 }
28 
29 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
30 public:
31   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
32   static constexpr const char *pyClassName = "AffineMapAttr";
33   using PyConcreteAttribute::PyConcreteAttribute;
34 
35   static void bindDerived(ClassTy &c) {
36     c.def_static(
37         "get",
38         [](PyAffineMap &affineMap) {
39           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
40           return PyAffineMapAttribute(affineMap.getContext(), attr);
41         },
42         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
43   }
44 };
45 
46 template <typename T>
47 static T pyTryCast(py::handle object) {
48   try {
49     return object.cast<T>();
50   } catch (py::cast_error &err) {
51     std::string msg =
52         std::string(
53             "Invalid attribute when attempting to create an ArrayAttribute (") +
54         err.what() + ")";
55     throw py::cast_error(msg);
56   } catch (py::reference_cast_error &err) {
57     std::string msg = std::string("Invalid attribute (None?) when attempting "
58                                   "to create an ArrayAttribute (") +
59                       err.what() + ")";
60     throw py::cast_error(msg);
61   }
62 }
63 
64 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
65 public:
66   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
67   static constexpr const char *pyClassName = "ArrayAttr";
68   using PyConcreteAttribute::PyConcreteAttribute;
69 
70   class PyArrayAttributeIterator {
71   public:
72     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
73 
74     PyArrayAttributeIterator &dunderIter() { return *this; }
75 
76     PyAttribute dunderNext() {
77       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
78         throw py::stop_iteration();
79       }
80       return PyAttribute(attr.getContext(),
81                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
82     }
83 
84     static void bind(py::module &m) {
85       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
86                                            py::module_local())
87           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
88           .def("__next__", &PyArrayAttributeIterator::dunderNext);
89     }
90 
91   private:
92     PyAttribute attr;
93     int nextIndex = 0;
94   };
95 
96   PyAttribute getItem(intptr_t i) {
97     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
98   }
99 
100   static void bindDerived(ClassTy &c) {
101     c.def_static(
102         "get",
103         [](py::list attributes, DefaultingPyMlirContext context) {
104           SmallVector<MlirAttribute> mlirAttributes;
105           mlirAttributes.reserve(py::len(attributes));
106           for (auto attribute : attributes) {
107             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
108           }
109           MlirAttribute attr = mlirArrayAttrGet(
110               context->get(), mlirAttributes.size(), mlirAttributes.data());
111           return PyArrayAttribute(context->getRef(), attr);
112         },
113         py::arg("attributes"), py::arg("context") = py::none(),
114         "Gets a uniqued Array attribute");
115     c.def("__getitem__",
116           [](PyArrayAttribute &arr, intptr_t i) {
117             if (i >= mlirArrayAttrGetNumElements(arr))
118               throw py::index_error("ArrayAttribute index out of range");
119             return arr.getItem(i);
120           })
121         .def("__len__",
122              [](const PyArrayAttribute &arr) {
123                return mlirArrayAttrGetNumElements(arr);
124              })
125         .def("__iter__", [](const PyArrayAttribute &arr) {
126           return PyArrayAttributeIterator(arr);
127         });
128     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
129       std::vector<MlirAttribute> attributes;
130       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
131       attributes.reserve(numOldElements + py::len(extras));
132       for (intptr_t i = 0; i < numOldElements; ++i)
133         attributes.push_back(arr.getItem(i));
134       for (py::handle attr : extras)
135         attributes.push_back(pyTryCast<PyAttribute>(attr));
136       MlirAttribute arrayAttr = mlirArrayAttrGet(
137           arr.getContext()->get(), attributes.size(), attributes.data());
138       return PyArrayAttribute(arr.getContext(), arrayAttr);
139     });
140   }
141 };
142 
143 /// Float Point Attribute subclass - FloatAttr.
144 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
145 public:
146   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
147   static constexpr const char *pyClassName = "FloatAttr";
148   using PyConcreteAttribute::PyConcreteAttribute;
149 
150   static void bindDerived(ClassTy &c) {
151     c.def_static(
152         "get",
153         [](PyType &type, double value, DefaultingPyLocation loc) {
154           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
155           // TODO: Rework error reporting once diagnostic engine is exposed
156           // in C API.
157           if (mlirAttributeIsNull(attr)) {
158             throw SetPyError(PyExc_ValueError,
159                              Twine("invalid '") +
160                                  py::repr(py::cast(type)).cast<std::string>() +
161                                  "' and expected floating point type.");
162           }
163           return PyFloatAttribute(type.getContext(), attr);
164         },
165         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
166         "Gets an uniqued float point attribute associated to a type");
167     c.def_static(
168         "get_f32",
169         [](double value, DefaultingPyMlirContext context) {
170           MlirAttribute attr = mlirFloatAttrDoubleGet(
171               context->get(), mlirF32TypeGet(context->get()), value);
172           return PyFloatAttribute(context->getRef(), attr);
173         },
174         py::arg("value"), py::arg("context") = py::none(),
175         "Gets an uniqued float point attribute associated to a f32 type");
176     c.def_static(
177         "get_f64",
178         [](double value, DefaultingPyMlirContext context) {
179           MlirAttribute attr = mlirFloatAttrDoubleGet(
180               context->get(), mlirF64TypeGet(context->get()), value);
181           return PyFloatAttribute(context->getRef(), attr);
182         },
183         py::arg("value"), py::arg("context") = py::none(),
184         "Gets an uniqued float point attribute associated to a f64 type");
185     c.def_property_readonly(
186         "value",
187         [](PyFloatAttribute &self) {
188           return mlirFloatAttrGetValueDouble(self);
189         },
190         "Returns the value of the float point attribute");
191   }
192 };
193 
194 /// Integer Attribute subclass - IntegerAttr.
195 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
196 public:
197   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
198   static constexpr const char *pyClassName = "IntegerAttr";
199   using PyConcreteAttribute::PyConcreteAttribute;
200 
201   static void bindDerived(ClassTy &c) {
202     c.def_static(
203         "get",
204         [](PyType &type, int64_t value) {
205           MlirAttribute attr = mlirIntegerAttrGet(type, value);
206           return PyIntegerAttribute(type.getContext(), attr);
207         },
208         py::arg("type"), py::arg("value"),
209         "Gets an uniqued integer attribute associated to a type");
210     c.def_property_readonly(
211         "value",
212         [](PyIntegerAttribute &self) {
213           return mlirIntegerAttrGetValueInt(self);
214         },
215         "Returns the value of the integer attribute");
216   }
217 };
218 
219 /// Bool Attribute subclass - BoolAttr.
220 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
221 public:
222   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
223   static constexpr const char *pyClassName = "BoolAttr";
224   using PyConcreteAttribute::PyConcreteAttribute;
225 
226   static void bindDerived(ClassTy &c) {
227     c.def_static(
228         "get",
229         [](bool value, DefaultingPyMlirContext context) {
230           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
231           return PyBoolAttribute(context->getRef(), attr);
232         },
233         py::arg("value"), py::arg("context") = py::none(),
234         "Gets an uniqued bool attribute");
235     c.def_property_readonly(
236         "value",
237         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
238         "Returns the value of the bool attribute");
239   }
240 };
241 
242 class PyFlatSymbolRefAttribute
243     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
244 public:
245   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
246   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
247   using PyConcreteAttribute::PyConcreteAttribute;
248 
249   static void bindDerived(ClassTy &c) {
250     c.def_static(
251         "get",
252         [](std::string value, DefaultingPyMlirContext context) {
253           MlirAttribute attr =
254               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
255           return PyFlatSymbolRefAttribute(context->getRef(), attr);
256         },
257         py::arg("value"), py::arg("context") = py::none(),
258         "Gets a uniqued FlatSymbolRef attribute");
259     c.def_property_readonly(
260         "value",
261         [](PyFlatSymbolRefAttribute &self) {
262           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
263           return py::str(stringRef.data, stringRef.length);
264         },
265         "Returns the value of the FlatSymbolRef attribute as a string");
266   }
267 };
268 
269 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
270 public:
271   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
272   static constexpr const char *pyClassName = "StringAttr";
273   using PyConcreteAttribute::PyConcreteAttribute;
274 
275   static void bindDerived(ClassTy &c) {
276     c.def_static(
277         "get",
278         [](std::string value, DefaultingPyMlirContext context) {
279           MlirAttribute attr =
280               mlirStringAttrGet(context->get(), toMlirStringRef(value));
281           return PyStringAttribute(context->getRef(), attr);
282         },
283         py::arg("value"), py::arg("context") = py::none(),
284         "Gets a uniqued string attribute");
285     c.def_static(
286         "get_typed",
287         [](PyType &type, std::string value) {
288           MlirAttribute attr =
289               mlirStringAttrTypedGet(type, toMlirStringRef(value));
290           return PyStringAttribute(type.getContext(), attr);
291         },
292 
293         "Gets a uniqued string attribute associated to a type");
294     c.def_property_readonly(
295         "value",
296         [](PyStringAttribute &self) {
297           MlirStringRef stringRef = mlirStringAttrGetValue(self);
298           return py::str(stringRef.data, stringRef.length);
299         },
300         "Returns the value of the string attribute");
301   }
302 };
303 
304 // TODO: Support construction of bool elements.
305 // TODO: Support construction of string elements.
306 class PyDenseElementsAttribute
307     : public PyConcreteAttribute<PyDenseElementsAttribute> {
308 public:
309   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
310   static constexpr const char *pyClassName = "DenseElementsAttr";
311   using PyConcreteAttribute::PyConcreteAttribute;
312 
313   static PyDenseElementsAttribute
314   getFromBuffer(py::buffer array, bool signless,
315                 DefaultingPyMlirContext contextWrapper) {
316     // Request a contiguous view. In exotic cases, this will cause a copy.
317     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
318     Py_buffer *view = new Py_buffer();
319     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
320       delete view;
321       throw py::error_already_set();
322     }
323     py::buffer_info arrayInfo(view);
324 
325     MlirContext context = contextWrapper->get();
326     // Switch on the types that can be bulk loaded between the Python and
327     // MLIR-C APIs.
328     // See: https://docs.python.org/3/library/struct.html#format-characters
329     if (arrayInfo.format == "f") {
330       // f32
331       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
332       return PyDenseElementsAttribute(
333           contextWrapper->getRef(),
334           bulkLoad(context, mlirDenseElementsAttrFloatGet,
335                    mlirF32TypeGet(context), arrayInfo));
336     } else if (arrayInfo.format == "d") {
337       // f64
338       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
339       return PyDenseElementsAttribute(
340           contextWrapper->getRef(),
341           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
342                    mlirF64TypeGet(context), arrayInfo));
343     } else if (isSignedIntegerFormat(arrayInfo.format)) {
344       if (arrayInfo.itemsize == 4) {
345         // i32
346         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
347                                         : mlirIntegerTypeSignedGet(context, 32);
348         return PyDenseElementsAttribute(contextWrapper->getRef(),
349                                         bulkLoad(context,
350                                                  mlirDenseElementsAttrInt32Get,
351                                                  elementType, arrayInfo));
352       } else if (arrayInfo.itemsize == 8) {
353         // i64
354         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
355                                         : mlirIntegerTypeSignedGet(context, 64);
356         return PyDenseElementsAttribute(contextWrapper->getRef(),
357                                         bulkLoad(context,
358                                                  mlirDenseElementsAttrInt64Get,
359                                                  elementType, arrayInfo));
360       }
361     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
362       if (arrayInfo.itemsize == 4) {
363         // unsigned i32
364         MlirType elementType = signless
365                                    ? mlirIntegerTypeGet(context, 32)
366                                    : mlirIntegerTypeUnsignedGet(context, 32);
367         return PyDenseElementsAttribute(contextWrapper->getRef(),
368                                         bulkLoad(context,
369                                                  mlirDenseElementsAttrUInt32Get,
370                                                  elementType, arrayInfo));
371       } else if (arrayInfo.itemsize == 8) {
372         // unsigned i64
373         MlirType elementType = signless
374                                    ? mlirIntegerTypeGet(context, 64)
375                                    : mlirIntegerTypeUnsignedGet(context, 64);
376         return PyDenseElementsAttribute(contextWrapper->getRef(),
377                                         bulkLoad(context,
378                                                  mlirDenseElementsAttrUInt64Get,
379                                                  elementType, arrayInfo));
380       }
381     }
382 
383     // TODO: Fall back to string-based get.
384     std::string message = "unimplemented array format conversion from format: ";
385     message.append(arrayInfo.format);
386     throw SetPyError(PyExc_ValueError, message);
387   }
388 
389   static PyDenseElementsAttribute getSplat(PyType shapedType,
390                                            PyAttribute &elementAttr) {
391     auto contextWrapper =
392         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
393     if (!mlirAttributeIsAInteger(elementAttr) &&
394         !mlirAttributeIsAFloat(elementAttr)) {
395       std::string message = "Illegal element type for DenseElementsAttr: ";
396       message.append(py::repr(py::cast(elementAttr)));
397       throw SetPyError(PyExc_ValueError, message);
398     }
399     if (!mlirTypeIsAShaped(shapedType) ||
400         !mlirShapedTypeHasStaticShape(shapedType)) {
401       std::string message =
402           "Expected a static ShapedType for the shaped_type parameter: ";
403       message.append(py::repr(py::cast(shapedType)));
404       throw SetPyError(PyExc_ValueError, message);
405     }
406     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
407     MlirType attrType = mlirAttributeGetType(elementAttr);
408     if (!mlirTypeEqual(shapedElementType, attrType)) {
409       std::string message =
410           "Shaped element type and attribute type must be equal: shaped=";
411       message.append(py::repr(py::cast(shapedType)));
412       message.append(", element=");
413       message.append(py::repr(py::cast(elementAttr)));
414       throw SetPyError(PyExc_ValueError, message);
415     }
416 
417     MlirAttribute elements =
418         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
419     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
420   }
421 
422   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
423 
424   py::buffer_info accessBuffer() {
425     MlirType shapedType = mlirAttributeGetType(*this);
426     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
427 
428     if (mlirTypeIsAF32(elementType)) {
429       // f32
430       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
431     } else if (mlirTypeIsAF64(elementType)) {
432       // f64
433       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
434     } else if (mlirTypeIsAInteger(elementType) &&
435                mlirIntegerTypeGetWidth(elementType) == 32) {
436       if (mlirIntegerTypeIsSignless(elementType) ||
437           mlirIntegerTypeIsSigned(elementType)) {
438         // i32
439         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
440       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
441         // unsigned i32
442         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
443       }
444     } else if (mlirTypeIsAInteger(elementType) &&
445                mlirIntegerTypeGetWidth(elementType) == 64) {
446       if (mlirIntegerTypeIsSignless(elementType) ||
447           mlirIntegerTypeIsSigned(elementType)) {
448         // i64
449         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
450       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
451         // unsigned i64
452         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
453       }
454     }
455 
456     std::string message = "unimplemented array format.";
457     throw SetPyError(PyExc_ValueError, message);
458   }
459 
460   static void bindDerived(ClassTy &c) {
461     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
462         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
463                     py::arg("array"), py::arg("signless") = true,
464                     py::arg("context") = py::none(),
465                     "Gets from a buffer or ndarray")
466         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
467                     py::arg("shaped_type"), py::arg("element_attr"),
468                     "Gets a DenseElementsAttr where all values are the same")
469         .def_property_readonly("is_splat",
470                                [](PyDenseElementsAttribute &self) -> bool {
471                                  return mlirDenseElementsAttrIsSplat(self);
472                                })
473         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
474   }
475 
476 private:
477   template <typename ElementTy>
478   static MlirAttribute
479   bulkLoad(MlirContext context,
480            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
481            MlirType mlirElementType, py::buffer_info &arrayInfo) {
482     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
483                                   arrayInfo.shape.begin() + arrayInfo.ndim);
484     MlirAttribute encodingAttr = mlirAttributeGetNull();
485     auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
486                                               mlirElementType, encodingAttr);
487     intptr_t numElements = arrayInfo.size;
488     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
489     return ctor(shapedType, numElements, contents);
490   }
491 
492   static bool isUnsignedIntegerFormat(const std::string &format) {
493     if (format.empty())
494       return false;
495     char code = format[0];
496     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
497            code == 'Q';
498   }
499 
500   static bool isSignedIntegerFormat(const std::string &format) {
501     if (format.empty())
502       return false;
503     char code = format[0];
504     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
505            code == 'q';
506   }
507 
508   template <typename Type>
509   py::buffer_info bufferInfo(MlirType shapedType,
510                              Type (*value)(MlirAttribute, intptr_t)) {
511     intptr_t rank = mlirShapedTypeGetRank(shapedType);
512     // Prepare the data for the buffer_info.
513     // Buffer is configured for read-only access below.
514     Type *data = static_cast<Type *>(
515         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
516     // Prepare the shape for the buffer_info.
517     SmallVector<intptr_t, 4> shape;
518     for (intptr_t i = 0; i < rank; ++i)
519       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
520     // Prepare the strides for the buffer_info.
521     SmallVector<intptr_t, 4> strides;
522     intptr_t strideFactor = 1;
523     for (intptr_t i = 1; i < rank; ++i) {
524       strideFactor = 1;
525       for (intptr_t j = i; j < rank; ++j) {
526         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
527       }
528       strides.push_back(sizeof(Type) * strideFactor);
529     }
530     strides.push_back(sizeof(Type));
531     return py::buffer_info(data, sizeof(Type),
532                            py::format_descriptor<Type>::format(), rank, shape,
533                            strides, /*readonly=*/true);
534   }
535 }; // namespace
536 
537 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
538 /// (and boolean) values. Supports element access.
539 class PyDenseIntElementsAttribute
540     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
541                                  PyDenseElementsAttribute> {
542 public:
543   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
544   static constexpr const char *pyClassName = "DenseIntElementsAttr";
545   using PyConcreteAttribute::PyConcreteAttribute;
546 
547   /// Returns the element at the given linear position. Asserts if the index is
548   /// out of range.
549   py::int_ dunderGetItem(intptr_t pos) {
550     if (pos < 0 || pos >= dunderLen()) {
551       throw SetPyError(PyExc_IndexError,
552                        "attempt to access out of bounds element");
553     }
554 
555     MlirType type = mlirAttributeGetType(*this);
556     type = mlirShapedTypeGetElementType(type);
557     assert(mlirTypeIsAInteger(type) &&
558            "expected integer element type in dense int elements attribute");
559     // Dispatch element extraction to an appropriate C function based on the
560     // elemental type of the attribute. py::int_ is implicitly constructible
561     // from any C++ integral type and handles bitwidth correctly.
562     // TODO: consider caching the type properties in the constructor to avoid
563     // querying them on each element access.
564     unsigned width = mlirIntegerTypeGetWidth(type);
565     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
566     if (isUnsigned) {
567       if (width == 1) {
568         return mlirDenseElementsAttrGetBoolValue(*this, pos);
569       }
570       if (width == 32) {
571         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
572       }
573       if (width == 64) {
574         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
575       }
576     } else {
577       if (width == 1) {
578         return mlirDenseElementsAttrGetBoolValue(*this, pos);
579       }
580       if (width == 32) {
581         return mlirDenseElementsAttrGetInt32Value(*this, pos);
582       }
583       if (width == 64) {
584         return mlirDenseElementsAttrGetInt64Value(*this, pos);
585       }
586     }
587     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
588   }
589 
590   static void bindDerived(ClassTy &c) {
591     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
592   }
593 };
594 
595 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
596 public:
597   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
598   static constexpr const char *pyClassName = "DictAttr";
599   using PyConcreteAttribute::PyConcreteAttribute;
600 
601   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
602 
603   static void bindDerived(ClassTy &c) {
604     c.def("__len__", &PyDictAttribute::dunderLen);
605     c.def_static(
606         "get",
607         [](py::dict attributes, DefaultingPyMlirContext context) {
608           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
609           mlirNamedAttributes.reserve(attributes.size());
610           for (auto &it : attributes) {
611             auto &mlir_attr = it.second.cast<PyAttribute &>();
612             auto name = it.first.cast<std::string>();
613             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
614                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
615                                   toMlirStringRef(name)),
616                 mlir_attr));
617           }
618           MlirAttribute attr =
619               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
620                                     mlirNamedAttributes.data());
621           return PyDictAttribute(context->getRef(), attr);
622         },
623         py::arg("value") = py::dict(), py::arg("context") = py::none(),
624         "Gets an uniqued dict attribute");
625     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
626       MlirAttribute attr =
627           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
628       if (mlirAttributeIsNull(attr)) {
629         throw SetPyError(PyExc_KeyError,
630                          "attempt to access a non-existent attribute");
631       }
632       return PyAttribute(self.getContext(), attr);
633     });
634     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
635       if (index < 0 || index >= self.dunderLen()) {
636         throw SetPyError(PyExc_IndexError,
637                          "attempt to access out of bounds attribute");
638       }
639       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
640       return PyNamedAttribute(
641           namedAttr.attribute,
642           std::string(mlirIdentifierStr(namedAttr.name).data));
643     });
644   }
645 };
646 
647 /// Refinement of PyDenseElementsAttribute for attributes containing
648 /// floating-point values. Supports element access.
649 class PyDenseFPElementsAttribute
650     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
651                                  PyDenseElementsAttribute> {
652 public:
653   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
654   static constexpr const char *pyClassName = "DenseFPElementsAttr";
655   using PyConcreteAttribute::PyConcreteAttribute;
656 
657   py::float_ dunderGetItem(intptr_t pos) {
658     if (pos < 0 || pos >= dunderLen()) {
659       throw SetPyError(PyExc_IndexError,
660                        "attempt to access out of bounds element");
661     }
662 
663     MlirType type = mlirAttributeGetType(*this);
664     type = mlirShapedTypeGetElementType(type);
665     // Dispatch element extraction to an appropriate C function based on the
666     // elemental type of the attribute. py::float_ is implicitly constructible
667     // from float and double.
668     // TODO: consider caching the type properties in the constructor to avoid
669     // querying them on each element access.
670     if (mlirTypeIsAF32(type)) {
671       return mlirDenseElementsAttrGetFloatValue(*this, pos);
672     }
673     if (mlirTypeIsAF64(type)) {
674       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
675     }
676     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
677   }
678 
679   static void bindDerived(ClassTy &c) {
680     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
681   }
682 };
683 
684 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
685 public:
686   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
687   static constexpr const char *pyClassName = "TypeAttr";
688   using PyConcreteAttribute::PyConcreteAttribute;
689 
690   static void bindDerived(ClassTy &c) {
691     c.def_static(
692         "get",
693         [](PyType value, DefaultingPyMlirContext context) {
694           MlirAttribute attr = mlirTypeAttrGet(value.get());
695           return PyTypeAttribute(context->getRef(), attr);
696         },
697         py::arg("value"), py::arg("context") = py::none(),
698         "Gets a uniqued Type attribute");
699     c.def_property_readonly("value", [](PyTypeAttribute &self) {
700       return PyType(self.getContext()->getRef(),
701                     mlirTypeAttrGetValue(self.get()));
702     });
703   }
704 };
705 
706 /// Unit Attribute subclass. Unit attributes don't have values.
707 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
708 public:
709   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
710   static constexpr const char *pyClassName = "UnitAttr";
711   using PyConcreteAttribute::PyConcreteAttribute;
712 
713   static void bindDerived(ClassTy &c) {
714     c.def_static(
715         "get",
716         [](DefaultingPyMlirContext context) {
717           return PyUnitAttribute(context->getRef(),
718                                  mlirUnitAttrGet(context->get()));
719         },
720         py::arg("context") = py::none(), "Create a Unit attribute.");
721   }
722 };
723 
724 } // namespace
725 
/// Registers every attribute subclass defined in this file, plus the
/// ArrayAttr iterator helper type, on module `m`.
/// NOTE(review): keep PyDenseElementsAttribute bound before
/// PyDenseFPElementsAttribute and PyDenseIntElementsAttribute. Those classes
/// name it as their base, and pybind11 needs a base type registered before
/// any class deriving from it.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base class first, then its refinements (see note above).
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
742