1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::Optional;
21 using llvm::SmallVector;
22 using llvm::Twine;
23 
24 //------------------------------------------------------------------------------
25 // Docstrings (trivial, non-duplicated docstrings are included inline).
26 //------------------------------------------------------------------------------
27 
/// Docstring text attached to the Python-visible static `get` method that
/// PyDenseElementsAttribute::bindDerived registers. Kept out of line because
/// of its length; the text is emitted verbatim to Python users.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
69 
70 namespace {
71 
72 static MlirStringRef toMlirStringRef(const std::string &s) {
73   return mlirStringRefCreate(s.data(), s.size());
74 }
75 
76 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
77 public:
78   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
79   static constexpr const char *pyClassName = "AffineMapAttr";
80   using PyConcreteAttribute::PyConcreteAttribute;
81 
82   static void bindDerived(ClassTy &c) {
83     c.def_static(
84         "get",
85         [](PyAffineMap &affineMap) {
86           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
87           return PyAffineMapAttribute(affineMap.getContext(), attr);
88         },
89         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
90   }
91 };
92 
93 template <typename T>
94 static T pyTryCast(py::handle object) {
95   try {
96     return object.cast<T>();
97   } catch (py::cast_error &err) {
98     std::string msg =
99         std::string(
100             "Invalid attribute when attempting to create an ArrayAttribute (") +
101         err.what() + ")";
102     throw py::cast_error(msg);
103   } catch (py::reference_cast_error &err) {
104     std::string msg = std::string("Invalid attribute (None?) when attempting "
105                                   "to create an ArrayAttribute (") +
106                       err.what() + ")";
107     throw py::cast_error(msg);
108   }
109 }
110 
111 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
112 public:
113   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
114   static constexpr const char *pyClassName = "ArrayAttr";
115   using PyConcreteAttribute::PyConcreteAttribute;
116 
117   class PyArrayAttributeIterator {
118   public:
119     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
120 
121     PyArrayAttributeIterator &dunderIter() { return *this; }
122 
123     PyAttribute dunderNext() {
124       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
125         throw py::stop_iteration();
126       }
127       return PyAttribute(attr.getContext(),
128                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
129     }
130 
131     static void bind(py::module &m) {
132       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
133                                            py::module_local())
134           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
135           .def("__next__", &PyArrayAttributeIterator::dunderNext);
136     }
137 
138   private:
139     PyAttribute attr;
140     int nextIndex = 0;
141   };
142 
143   PyAttribute getItem(intptr_t i) {
144     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
145   }
146 
147   static void bindDerived(ClassTy &c) {
148     c.def_static(
149         "get",
150         [](py::list attributes, DefaultingPyMlirContext context) {
151           SmallVector<MlirAttribute> mlirAttributes;
152           mlirAttributes.reserve(py::len(attributes));
153           for (auto attribute : attributes) {
154             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
155           }
156           MlirAttribute attr = mlirArrayAttrGet(
157               context->get(), mlirAttributes.size(), mlirAttributes.data());
158           return PyArrayAttribute(context->getRef(), attr);
159         },
160         py::arg("attributes"), py::arg("context") = py::none(),
161         "Gets a uniqued Array attribute");
162     c.def("__getitem__",
163           [](PyArrayAttribute &arr, intptr_t i) {
164             if (i >= mlirArrayAttrGetNumElements(arr))
165               throw py::index_error("ArrayAttribute index out of range");
166             return arr.getItem(i);
167           })
168         .def("__len__",
169              [](const PyArrayAttribute &arr) {
170                return mlirArrayAttrGetNumElements(arr);
171              })
172         .def("__iter__", [](const PyArrayAttribute &arr) {
173           return PyArrayAttributeIterator(arr);
174         });
175     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
176       std::vector<MlirAttribute> attributes;
177       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
178       attributes.reserve(numOldElements + py::len(extras));
179       for (intptr_t i = 0; i < numOldElements; ++i)
180         attributes.push_back(arr.getItem(i));
181       for (py::handle attr : extras)
182         attributes.push_back(pyTryCast<PyAttribute>(attr));
183       MlirAttribute arrayAttr = mlirArrayAttrGet(
184           arr.getContext()->get(), attributes.size(), attributes.data());
185       return PyArrayAttribute(arr.getContext(), arrayAttr);
186     });
187   }
188 };
189 
190 /// Float Point Attribute subclass - FloatAttr.
191 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
192 public:
193   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
194   static constexpr const char *pyClassName = "FloatAttr";
195   using PyConcreteAttribute::PyConcreteAttribute;
196 
197   static void bindDerived(ClassTy &c) {
198     c.def_static(
199         "get",
200         [](PyType &type, double value, DefaultingPyLocation loc) {
201           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
202           // TODO: Rework error reporting once diagnostic engine is exposed
203           // in C API.
204           if (mlirAttributeIsNull(attr)) {
205             throw SetPyError(PyExc_ValueError,
206                              Twine("invalid '") +
207                                  py::repr(py::cast(type)).cast<std::string>() +
208                                  "' and expected floating point type.");
209           }
210           return PyFloatAttribute(type.getContext(), attr);
211         },
212         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
213         "Gets an uniqued float point attribute associated to a type");
214     c.def_static(
215         "get_f32",
216         [](double value, DefaultingPyMlirContext context) {
217           MlirAttribute attr = mlirFloatAttrDoubleGet(
218               context->get(), mlirF32TypeGet(context->get()), value);
219           return PyFloatAttribute(context->getRef(), attr);
220         },
221         py::arg("value"), py::arg("context") = py::none(),
222         "Gets an uniqued float point attribute associated to a f32 type");
223     c.def_static(
224         "get_f64",
225         [](double value, DefaultingPyMlirContext context) {
226           MlirAttribute attr = mlirFloatAttrDoubleGet(
227               context->get(), mlirF64TypeGet(context->get()), value);
228           return PyFloatAttribute(context->getRef(), attr);
229         },
230         py::arg("value"), py::arg("context") = py::none(),
231         "Gets an uniqued float point attribute associated to a f64 type");
232     c.def_property_readonly(
233         "value",
234         [](PyFloatAttribute &self) {
235           return mlirFloatAttrGetValueDouble(self);
236         },
237         "Returns the value of the float point attribute");
238   }
239 };
240 
241 /// Integer Attribute subclass - IntegerAttr.
242 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
243 public:
244   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
245   static constexpr const char *pyClassName = "IntegerAttr";
246   using PyConcreteAttribute::PyConcreteAttribute;
247 
248   static void bindDerived(ClassTy &c) {
249     c.def_static(
250         "get",
251         [](PyType &type, int64_t value) {
252           MlirAttribute attr = mlirIntegerAttrGet(type, value);
253           return PyIntegerAttribute(type.getContext(), attr);
254         },
255         py::arg("type"), py::arg("value"),
256         "Gets an uniqued integer attribute associated to a type");
257     c.def_property_readonly(
258         "value",
259         [](PyIntegerAttribute &self) {
260           return mlirIntegerAttrGetValueInt(self);
261         },
262         "Returns the value of the integer attribute");
263   }
264 };
265 
266 /// Bool Attribute subclass - BoolAttr.
267 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
268 public:
269   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
270   static constexpr const char *pyClassName = "BoolAttr";
271   using PyConcreteAttribute::PyConcreteAttribute;
272 
273   static void bindDerived(ClassTy &c) {
274     c.def_static(
275         "get",
276         [](bool value, DefaultingPyMlirContext context) {
277           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
278           return PyBoolAttribute(context->getRef(), attr);
279         },
280         py::arg("value"), py::arg("context") = py::none(),
281         "Gets an uniqued bool attribute");
282     c.def_property_readonly(
283         "value",
284         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
285         "Returns the value of the bool attribute");
286   }
287 };
288 
289 class PyFlatSymbolRefAttribute
290     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
291 public:
292   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
293   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
294   using PyConcreteAttribute::PyConcreteAttribute;
295 
296   static void bindDerived(ClassTy &c) {
297     c.def_static(
298         "get",
299         [](std::string value, DefaultingPyMlirContext context) {
300           MlirAttribute attr =
301               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
302           return PyFlatSymbolRefAttribute(context->getRef(), attr);
303         },
304         py::arg("value"), py::arg("context") = py::none(),
305         "Gets a uniqued FlatSymbolRef attribute");
306     c.def_property_readonly(
307         "value",
308         [](PyFlatSymbolRefAttribute &self) {
309           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
310           return py::str(stringRef.data, stringRef.length);
311         },
312         "Returns the value of the FlatSymbolRef attribute as a string");
313   }
314 };
315 
316 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
317 public:
318   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
319   static constexpr const char *pyClassName = "StringAttr";
320   using PyConcreteAttribute::PyConcreteAttribute;
321 
322   static void bindDerived(ClassTy &c) {
323     c.def_static(
324         "get",
325         [](std::string value, DefaultingPyMlirContext context) {
326           MlirAttribute attr =
327               mlirStringAttrGet(context->get(), toMlirStringRef(value));
328           return PyStringAttribute(context->getRef(), attr);
329         },
330         py::arg("value"), py::arg("context") = py::none(),
331         "Gets a uniqued string attribute");
332     c.def_static(
333         "get_typed",
334         [](PyType &type, std::string value) {
335           MlirAttribute attr =
336               mlirStringAttrTypedGet(type, toMlirStringRef(value));
337           return PyStringAttribute(type.getContext(), attr);
338         },
339         py::arg("type"), py::arg("value"),
340         "Gets a uniqued string attribute associated to a type");
341     c.def_property_readonly(
342         "value",
343         [](PyStringAttribute &self) {
344           MlirStringRef stringRef = mlirStringAttrGetValue(self);
345           return py::str(stringRef.data, stringRef.length);
346         },
347         "Returns the value of the string attribute");
348   }
349 };
350 
351 // TODO: Support construction of string elements.
352 class PyDenseElementsAttribute
353     : public PyConcreteAttribute<PyDenseElementsAttribute> {
354 public:
355   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
356   static constexpr const char *pyClassName = "DenseElementsAttr";
357   using PyConcreteAttribute::PyConcreteAttribute;
358 
359   static PyDenseElementsAttribute
360   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
361                 Optional<std::vector<int64_t>> explicitShape,
362                 DefaultingPyMlirContext contextWrapper) {
363     // Request a contiguous view. In exotic cases, this will cause a copy.
364     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
365     Py_buffer *view = new Py_buffer();
366     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
367       delete view;
368       throw py::error_already_set();
369     }
370     py::buffer_info arrayInfo(view);
371     SmallVector<int64_t> shape;
372     if (explicitShape) {
373       shape.append(explicitShape->begin(), explicitShape->end());
374     } else {
375       shape.append(arrayInfo.shape.begin(),
376                    arrayInfo.shape.begin() + arrayInfo.ndim);
377     }
378 
379     MlirAttribute encodingAttr = mlirAttributeGetNull();
380     MlirContext context = contextWrapper->get();
381 
382     // Detect format codes that are suitable for bulk loading. This includes
383     // all byte aligned integer and floating point types up to 8 bytes.
384     // Notably, this excludes, bool (which needs to be bit-packed) and
385     // other exotics which do not have a direct representation in the buffer
386     // protocol (i.e. complex, etc).
387     Optional<MlirType> bulkLoadElementType;
388     if (explicitType) {
389       bulkLoadElementType = *explicitType;
390     } else if (arrayInfo.format == "f") {
391       // f32
392       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
393       bulkLoadElementType = mlirF32TypeGet(context);
394     } else if (arrayInfo.format == "d") {
395       // f64
396       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
397       bulkLoadElementType = mlirF64TypeGet(context);
398     } else if (arrayInfo.format == "e") {
399       // f16
400       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
401       bulkLoadElementType = mlirF16TypeGet(context);
402     } else if (isSignedIntegerFormat(arrayInfo.format)) {
403       if (arrayInfo.itemsize == 4) {
404         // i32
405         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
406                                        : mlirIntegerTypeSignedGet(context, 32);
407       } else if (arrayInfo.itemsize == 8) {
408         // i64
409         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
410                                        : mlirIntegerTypeSignedGet(context, 64);
411       } else if (arrayInfo.itemsize == 1) {
412         // i8
413         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
414                                        : mlirIntegerTypeSignedGet(context, 8);
415       } else if (arrayInfo.itemsize == 2) {
416         // i16
417         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
418                                        : mlirIntegerTypeSignedGet(context, 16);
419       }
420     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
421       if (arrayInfo.itemsize == 4) {
422         // unsigned i32
423         bulkLoadElementType = signless
424                                   ? mlirIntegerTypeGet(context, 32)
425                                   : mlirIntegerTypeUnsignedGet(context, 32);
426       } else if (arrayInfo.itemsize == 8) {
427         // unsigned i64
428         bulkLoadElementType = signless
429                                   ? mlirIntegerTypeGet(context, 64)
430                                   : mlirIntegerTypeUnsignedGet(context, 64);
431       } else if (arrayInfo.itemsize == 1) {
432         // i8
433         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
434                                        : mlirIntegerTypeUnsignedGet(context, 8);
435       } else if (arrayInfo.itemsize == 2) {
436         // i16
437         bulkLoadElementType = signless
438                                   ? mlirIntegerTypeGet(context, 16)
439                                   : mlirIntegerTypeUnsignedGet(context, 16);
440       }
441     }
442     if (bulkLoadElementType) {
443       auto shapedType = mlirRankedTensorTypeGet(
444           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
445       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
446       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
447           shapedType, rawBufferSize, arrayInfo.ptr);
448       if (mlirAttributeIsNull(attr)) {
449         throw std::invalid_argument(
450             "DenseElementsAttr could not be constructed from the given buffer. "
451             "This may mean that the Python buffer layout does not match that "
452             "MLIR expected layout and is a bug.");
453       }
454       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
455     }
456 
457     throw std::invalid_argument(
458         std::string("unimplemented array format conversion from format: ") +
459         arrayInfo.format);
460   }
461 
462   static PyDenseElementsAttribute getSplat(PyType shapedType,
463                                            PyAttribute &elementAttr) {
464     auto contextWrapper =
465         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
466     if (!mlirAttributeIsAInteger(elementAttr) &&
467         !mlirAttributeIsAFloat(elementAttr)) {
468       std::string message = "Illegal element type for DenseElementsAttr: ";
469       message.append(py::repr(py::cast(elementAttr)));
470       throw SetPyError(PyExc_ValueError, message);
471     }
472     if (!mlirTypeIsAShaped(shapedType) ||
473         !mlirShapedTypeHasStaticShape(shapedType)) {
474       std::string message =
475           "Expected a static ShapedType for the shaped_type parameter: ";
476       message.append(py::repr(py::cast(shapedType)));
477       throw SetPyError(PyExc_ValueError, message);
478     }
479     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
480     MlirType attrType = mlirAttributeGetType(elementAttr);
481     if (!mlirTypeEqual(shapedElementType, attrType)) {
482       std::string message =
483           "Shaped element type and attribute type must be equal: shaped=";
484       message.append(py::repr(py::cast(shapedType)));
485       message.append(", element=");
486       message.append(py::repr(py::cast(elementAttr)));
487       throw SetPyError(PyExc_ValueError, message);
488     }
489 
490     MlirAttribute elements =
491         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
492     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
493   }
494 
495   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
496 
497   py::buffer_info accessBuffer() {
498     if (mlirDenseElementsAttrIsSplat(*this)) {
499       // TODO: Currently crashes the program.
500       // Reported as https://github.com/pybind/pybind11/issues/3336
501       throw std::invalid_argument(
502           "unsupported data type for conversion to Python buffer");
503     }
504 
505     MlirType shapedType = mlirAttributeGetType(*this);
506     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
507     std::string format;
508 
509     if (mlirTypeIsAF32(elementType)) {
510       // f32
511       return bufferInfo<float>(shapedType);
512     }
513     if (mlirTypeIsAF64(elementType)) {
514       // f64
515       return bufferInfo<double>(shapedType);
516     }
517     if (mlirTypeIsAF16(elementType)) {
518       // f16
519       return bufferInfo<uint16_t>(shapedType, "e");
520     }
521     if (mlirTypeIsAInteger(elementType) &&
522         mlirIntegerTypeGetWidth(elementType) == 32) {
523       if (mlirIntegerTypeIsSignless(elementType) ||
524           mlirIntegerTypeIsSigned(elementType)) {
525         // i32
526         return bufferInfo<int32_t>(shapedType);
527       }
528       if (mlirIntegerTypeIsUnsigned(elementType)) {
529         // unsigned i32
530         return bufferInfo<uint32_t>(shapedType);
531       }
532     } else if (mlirTypeIsAInteger(elementType) &&
533                mlirIntegerTypeGetWidth(elementType) == 64) {
534       if (mlirIntegerTypeIsSignless(elementType) ||
535           mlirIntegerTypeIsSigned(elementType)) {
536         // i64
537         return bufferInfo<int64_t>(shapedType);
538       }
539       if (mlirIntegerTypeIsUnsigned(elementType)) {
540         // unsigned i64
541         return bufferInfo<uint64_t>(shapedType);
542       }
543     } else if (mlirTypeIsAInteger(elementType) &&
544                mlirIntegerTypeGetWidth(elementType) == 8) {
545       if (mlirIntegerTypeIsSignless(elementType) ||
546           mlirIntegerTypeIsSigned(elementType)) {
547         // i8
548         return bufferInfo<int8_t>(shapedType);
549       }
550       if (mlirIntegerTypeIsUnsigned(elementType)) {
551         // unsigned i8
552         return bufferInfo<uint8_t>(shapedType);
553       }
554     } else if (mlirTypeIsAInteger(elementType) &&
555                mlirIntegerTypeGetWidth(elementType) == 16) {
556       if (mlirIntegerTypeIsSignless(elementType) ||
557           mlirIntegerTypeIsSigned(elementType)) {
558         // i16
559         return bufferInfo<int16_t>(shapedType);
560       }
561       if (mlirIntegerTypeIsUnsigned(elementType)) {
562         // unsigned i16
563         return bufferInfo<uint16_t>(shapedType);
564       }
565     }
566 
567     // TODO: Currently crashes the program.
568     // Reported as https://github.com/pybind/pybind11/issues/3336
569     throw std::invalid_argument(
570         "unsupported data type for conversion to Python buffer");
571   }
572 
573   static void bindDerived(ClassTy &c) {
574     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
575         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
576                     py::arg("array"), py::arg("signless") = true,
577                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
578                     py::arg("context") = py::none(),
579                     kDenseElementsAttrGetDocstring)
580         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
581                     py::arg("shaped_type"), py::arg("element_attr"),
582                     "Gets a DenseElementsAttr where all values are the same")
583         .def_property_readonly("is_splat",
584                                [](PyDenseElementsAttribute &self) -> bool {
585                                  return mlirDenseElementsAttrIsSplat(self);
586                                })
587         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
588   }
589 
590 private:
591   static bool isUnsignedIntegerFormat(const std::string &format) {
592     if (format.empty())
593       return false;
594     char code = format[0];
595     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
596            code == 'Q';
597   }
598 
599   static bool isSignedIntegerFormat(const std::string &format) {
600     if (format.empty())
601       return false;
602     char code = format[0];
603     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
604            code == 'q';
605   }
606 
607   template <typename Type>
608   py::buffer_info bufferInfo(MlirType shapedType,
609                              const char *explicitFormat = nullptr) {
610     intptr_t rank = mlirShapedTypeGetRank(shapedType);
611     // Prepare the data for the buffer_info.
612     // Buffer is configured for read-only access below.
613     Type *data = static_cast<Type *>(
614         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
615     // Prepare the shape for the buffer_info.
616     SmallVector<intptr_t, 4> shape;
617     for (intptr_t i = 0; i < rank; ++i)
618       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
619     // Prepare the strides for the buffer_info.
620     SmallVector<intptr_t, 4> strides;
621     intptr_t strideFactor = 1;
622     for (intptr_t i = 1; i < rank; ++i) {
623       strideFactor = 1;
624       for (intptr_t j = i; j < rank; ++j) {
625         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
626       }
627       strides.push_back(sizeof(Type) * strideFactor);
628     }
629     strides.push_back(sizeof(Type));
630     std::string format;
631     if (explicitFormat) {
632       format = explicitFormat;
633     } else {
634       format = py::format_descriptor<Type>::format();
635     }
636     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
637                            /*readonly=*/true);
638   }
639 }; // namespace
640 
641 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
642 /// (and boolean) values. Supports element access.
643 class PyDenseIntElementsAttribute
644     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
645                                  PyDenseElementsAttribute> {
646 public:
647   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
648   static constexpr const char *pyClassName = "DenseIntElementsAttr";
649   using PyConcreteAttribute::PyConcreteAttribute;
650 
651   /// Returns the element at the given linear position. Asserts if the index is
652   /// out of range.
653   py::int_ dunderGetItem(intptr_t pos) {
654     if (pos < 0 || pos >= dunderLen()) {
655       throw SetPyError(PyExc_IndexError,
656                        "attempt to access out of bounds element");
657     }
658 
659     MlirType type = mlirAttributeGetType(*this);
660     type = mlirShapedTypeGetElementType(type);
661     assert(mlirTypeIsAInteger(type) &&
662            "expected integer element type in dense int elements attribute");
663     // Dispatch element extraction to an appropriate C function based on the
664     // elemental type of the attribute. py::int_ is implicitly constructible
665     // from any C++ integral type and handles bitwidth correctly.
666     // TODO: consider caching the type properties in the constructor to avoid
667     // querying them on each element access.
668     unsigned width = mlirIntegerTypeGetWidth(type);
669     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
670     if (isUnsigned) {
671       if (width == 1) {
672         return mlirDenseElementsAttrGetBoolValue(*this, pos);
673       }
674       if (width == 32) {
675         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
676       }
677       if (width == 64) {
678         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
679       }
680     } else {
681       if (width == 1) {
682         return mlirDenseElementsAttrGetBoolValue(*this, pos);
683       }
684       if (width == 32) {
685         return mlirDenseElementsAttrGetInt32Value(*this, pos);
686       }
687       if (width == 64) {
688         return mlirDenseElementsAttrGetInt64Value(*this, pos);
689       }
690     }
691     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
692   }
693 
694   static void bindDerived(ClassTy &c) {
695     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
696   }
697 };
698 
699 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
700 public:
701   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
702   static constexpr const char *pyClassName = "DictAttr";
703   using PyConcreteAttribute::PyConcreteAttribute;
704 
705   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
706 
707   bool dunderContains(const std::string &name) {
708     return !mlirAttributeIsNull(
709         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
710   }
711 
712   static void bindDerived(ClassTy &c) {
713     c.def("__contains__", &PyDictAttribute::dunderContains);
714     c.def("__len__", &PyDictAttribute::dunderLen);
715     c.def_static(
716         "get",
717         [](py::dict attributes, DefaultingPyMlirContext context) {
718           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
719           mlirNamedAttributes.reserve(attributes.size());
720           for (auto &it : attributes) {
721             auto &mlirAttr = it.second.cast<PyAttribute &>();
722             auto name = it.first.cast<std::string>();
723             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
724                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
725                                   toMlirStringRef(name)),
726                 mlirAttr));
727           }
728           MlirAttribute attr =
729               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
730                                     mlirNamedAttributes.data());
731           return PyDictAttribute(context->getRef(), attr);
732         },
733         py::arg("value") = py::dict(), py::arg("context") = py::none(),
734         "Gets an uniqued dict attribute");
735     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
736       MlirAttribute attr =
737           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
738       if (mlirAttributeIsNull(attr)) {
739         throw SetPyError(PyExc_KeyError,
740                          "attempt to access a non-existent attribute");
741       }
742       return PyAttribute(self.getContext(), attr);
743     });
744     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
745       if (index < 0 || index >= self.dunderLen()) {
746         throw SetPyError(PyExc_IndexError,
747                          "attempt to access out of bounds attribute");
748       }
749       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
750       return PyNamedAttribute(
751           namedAttr.attribute,
752           std::string(mlirIdentifierStr(namedAttr.name).data));
753     });
754   }
755 };
756 
757 /// Refinement of PyDenseElementsAttribute for attributes containing
758 /// floating-point values. Supports element access.
759 class PyDenseFPElementsAttribute
760     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
761                                  PyDenseElementsAttribute> {
762 public:
763   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
764   static constexpr const char *pyClassName = "DenseFPElementsAttr";
765   using PyConcreteAttribute::PyConcreteAttribute;
766 
767   py::float_ dunderGetItem(intptr_t pos) {
768     if (pos < 0 || pos >= dunderLen()) {
769       throw SetPyError(PyExc_IndexError,
770                        "attempt to access out of bounds element");
771     }
772 
773     MlirType type = mlirAttributeGetType(*this);
774     type = mlirShapedTypeGetElementType(type);
775     // Dispatch element extraction to an appropriate C function based on the
776     // elemental type of the attribute. py::float_ is implicitly constructible
777     // from float and double.
778     // TODO: consider caching the type properties in the constructor to avoid
779     // querying them on each element access.
780     if (mlirTypeIsAF32(type)) {
781       return mlirDenseElementsAttrGetFloatValue(*this, pos);
782     }
783     if (mlirTypeIsAF64(type)) {
784       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
785     }
786     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
787   }
788 
789   static void bindDerived(ClassTy &c) {
790     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
791   }
792 };
793 
794 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
795 public:
796   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
797   static constexpr const char *pyClassName = "TypeAttr";
798   using PyConcreteAttribute::PyConcreteAttribute;
799 
800   static void bindDerived(ClassTy &c) {
801     c.def_static(
802         "get",
803         [](PyType value, DefaultingPyMlirContext context) {
804           MlirAttribute attr = mlirTypeAttrGet(value.get());
805           return PyTypeAttribute(context->getRef(), attr);
806         },
807         py::arg("value"), py::arg("context") = py::none(),
808         "Gets a uniqued Type attribute");
809     c.def_property_readonly("value", [](PyTypeAttribute &self) {
810       return PyType(self.getContext()->getRef(),
811                     mlirTypeAttrGetValue(self.get()));
812     });
813   }
814 };
815 
816 /// Unit Attribute subclass. Unit attributes don't have values.
817 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
818 public:
819   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
820   static constexpr const char *pyClassName = "UnitAttr";
821   using PyConcreteAttribute::PyConcreteAttribute;
822 
823   static void bindDerived(ClassTy &c) {
824     c.def_static(
825         "get",
826         [](DefaultingPyMlirContext context) {
827           return PyUnitAttribute(context->getRef(),
828                                  mlirUnitAttrGet(context->get()));
829         },
830         py::arg("context") = py::none(), "Create a Unit attribute.");
831   }
832 };
833 
834 } // namespace
835 
/// Registers the Python bindings for the attribute classes listed below on
/// module `m` by invoking each class's static `bind`.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // NOTE(review): PyDenseFPElementsAttribute names PyDenseElementsAttribute as
  // its base class, so the base presumably has to be bound before the
  // refinements that follow; keep this order unless confirmed otherwise.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
852