1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::None;
21 using llvm::Optional;
22 using llvm::SmallVector;
23 using llvm::Twine;
24 
25 //------------------------------------------------------------------------------
26 // Docstrings (trivial, non-duplicated docstrings are included inline).
27 //------------------------------------------------------------------------------
28 
// Python docstring for the `DenseElementsAttr.get` static method, which is
// registered in PyDenseElementsAttribute::bindDerived and implemented by
// PyDenseElementsAttribute::getFromBuffer. This is user-facing text: keep the
// described inference rules in sync with getFromBuffer when either changes.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
70 
71 namespace {
72 
73 static MlirStringRef toMlirStringRef(const std::string &s) {
74   return mlirStringRefCreate(s.data(), s.size());
75 }
76 
77 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78 public:
79   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80   static constexpr const char *pyClassName = "AffineMapAttr";
81   using PyConcreteAttribute::PyConcreteAttribute;
82 
83   static void bindDerived(ClassTy &c) {
84     c.def_static(
85         "get",
86         [](PyAffineMap &affineMap) {
87           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88           return PyAffineMapAttribute(affineMap.getContext(), attr);
89         },
90         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91   }
92 };
93 
94 template <typename T>
95 static T pyTryCast(py::handle object) {
96   try {
97     return object.cast<T>();
98   } catch (py::cast_error &err) {
99     std::string msg =
100         std::string(
101             "Invalid attribute when attempting to create an ArrayAttribute (") +
102         err.what() + ")";
103     throw py::cast_error(msg);
104   } catch (py::reference_cast_error &err) {
105     std::string msg = std::string("Invalid attribute (None?) when attempting "
106                                   "to create an ArrayAttribute (") +
107                       err.what() + ")";
108     throw py::cast_error(msg);
109   }
110 }
111 
112 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113 public:
114   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115   static constexpr const char *pyClassName = "ArrayAttr";
116   using PyConcreteAttribute::PyConcreteAttribute;
117 
118   class PyArrayAttributeIterator {
119   public:
120     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121 
122     PyArrayAttributeIterator &dunderIter() { return *this; }
123 
124     PyAttribute dunderNext() {
125       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126         throw py::stop_iteration();
127       }
128       return PyAttribute(attr.getContext(),
129                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
130     }
131 
132     static void bind(py::module &m) {
133       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134                                            py::module_local())
135           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136           .def("__next__", &PyArrayAttributeIterator::dunderNext);
137     }
138 
139   private:
140     PyAttribute attr;
141     int nextIndex = 0;
142   };
143 
144   PyAttribute getItem(intptr_t i) {
145     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146   }
147 
148   static void bindDerived(ClassTy &c) {
149     c.def_static(
150         "get",
151         [](py::list attributes, DefaultingPyMlirContext context) {
152           SmallVector<MlirAttribute> mlirAttributes;
153           mlirAttributes.reserve(py::len(attributes));
154           for (auto attribute : attributes) {
155             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
156           }
157           MlirAttribute attr = mlirArrayAttrGet(
158               context->get(), mlirAttributes.size(), mlirAttributes.data());
159           return PyArrayAttribute(context->getRef(), attr);
160         },
161         py::arg("attributes"), py::arg("context") = py::none(),
162         "Gets a uniqued Array attribute");
163     c.def("__getitem__",
164           [](PyArrayAttribute &arr, intptr_t i) {
165             if (i >= mlirArrayAttrGetNumElements(arr))
166               throw py::index_error("ArrayAttribute index out of range");
167             return arr.getItem(i);
168           })
169         .def("__len__",
170              [](const PyArrayAttribute &arr) {
171                return mlirArrayAttrGetNumElements(arr);
172              })
173         .def("__iter__", [](const PyArrayAttribute &arr) {
174           return PyArrayAttributeIterator(arr);
175         });
176     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177       std::vector<MlirAttribute> attributes;
178       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179       attributes.reserve(numOldElements + py::len(extras));
180       for (intptr_t i = 0; i < numOldElements; ++i)
181         attributes.push_back(arr.getItem(i));
182       for (py::handle attr : extras)
183         attributes.push_back(pyTryCast<PyAttribute>(attr));
184       MlirAttribute arrayAttr = mlirArrayAttrGet(
185           arr.getContext()->get(), attributes.size(), attributes.data());
186       return PyArrayAttribute(arr.getContext(), arrayAttr);
187     });
188   }
189 };
190 
191 /// Float Point Attribute subclass - FloatAttr.
192 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193 public:
194   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195   static constexpr const char *pyClassName = "FloatAttr";
196   using PyConcreteAttribute::PyConcreteAttribute;
197 
198   static void bindDerived(ClassTy &c) {
199     c.def_static(
200         "get",
201         [](PyType &type, double value, DefaultingPyLocation loc) {
202           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203           // TODO: Rework error reporting once diagnostic engine is exposed
204           // in C API.
205           if (mlirAttributeIsNull(attr)) {
206             throw SetPyError(PyExc_ValueError,
207                              Twine("invalid '") +
208                                  py::repr(py::cast(type)).cast<std::string>() +
209                                  "' and expected floating point type.");
210           }
211           return PyFloatAttribute(type.getContext(), attr);
212         },
213         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214         "Gets an uniqued float point attribute associated to a type");
215     c.def_static(
216         "get_f32",
217         [](double value, DefaultingPyMlirContext context) {
218           MlirAttribute attr = mlirFloatAttrDoubleGet(
219               context->get(), mlirF32TypeGet(context->get()), value);
220           return PyFloatAttribute(context->getRef(), attr);
221         },
222         py::arg("value"), py::arg("context") = py::none(),
223         "Gets an uniqued float point attribute associated to a f32 type");
224     c.def_static(
225         "get_f64",
226         [](double value, DefaultingPyMlirContext context) {
227           MlirAttribute attr = mlirFloatAttrDoubleGet(
228               context->get(), mlirF64TypeGet(context->get()), value);
229           return PyFloatAttribute(context->getRef(), attr);
230         },
231         py::arg("value"), py::arg("context") = py::none(),
232         "Gets an uniqued float point attribute associated to a f64 type");
233     c.def_property_readonly(
234         "value",
235         [](PyFloatAttribute &self) {
236           return mlirFloatAttrGetValueDouble(self);
237         },
238         "Returns the value of the float point attribute");
239   }
240 };
241 
242 /// Integer Attribute subclass - IntegerAttr.
243 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244 public:
245   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246   static constexpr const char *pyClassName = "IntegerAttr";
247   using PyConcreteAttribute::PyConcreteAttribute;
248 
249   static void bindDerived(ClassTy &c) {
250     c.def_static(
251         "get",
252         [](PyType &type, int64_t value) {
253           MlirAttribute attr = mlirIntegerAttrGet(type, value);
254           return PyIntegerAttribute(type.getContext(), attr);
255         },
256         py::arg("type"), py::arg("value"),
257         "Gets an uniqued integer attribute associated to a type");
258     c.def_property_readonly(
259         "value",
260         [](PyIntegerAttribute &self) {
261           return mlirIntegerAttrGetValueInt(self);
262         },
263         "Returns the value of the integer attribute");
264   }
265 };
266 
267 /// Bool Attribute subclass - BoolAttr.
268 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269 public:
270   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271   static constexpr const char *pyClassName = "BoolAttr";
272   using PyConcreteAttribute::PyConcreteAttribute;
273 
274   static void bindDerived(ClassTy &c) {
275     c.def_static(
276         "get",
277         [](bool value, DefaultingPyMlirContext context) {
278           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279           return PyBoolAttribute(context->getRef(), attr);
280         },
281         py::arg("value"), py::arg("context") = py::none(),
282         "Gets an uniqued bool attribute");
283     c.def_property_readonly(
284         "value",
285         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286         "Returns the value of the bool attribute");
287   }
288 };
289 
290 class PyFlatSymbolRefAttribute
291     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292 public:
293   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295   using PyConcreteAttribute::PyConcreteAttribute;
296 
297   static void bindDerived(ClassTy &c) {
298     c.def_static(
299         "get",
300         [](std::string value, DefaultingPyMlirContext context) {
301           MlirAttribute attr =
302               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303           return PyFlatSymbolRefAttribute(context->getRef(), attr);
304         },
305         py::arg("value"), py::arg("context") = py::none(),
306         "Gets a uniqued FlatSymbolRef attribute");
307     c.def_property_readonly(
308         "value",
309         [](PyFlatSymbolRefAttribute &self) {
310           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311           return py::str(stringRef.data, stringRef.length);
312         },
313         "Returns the value of the FlatSymbolRef attribute as a string");
314   }
315 };
316 
317 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318 public:
319   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320   static constexpr const char *pyClassName = "StringAttr";
321   using PyConcreteAttribute::PyConcreteAttribute;
322 
323   static void bindDerived(ClassTy &c) {
324     c.def_static(
325         "get",
326         [](std::string value, DefaultingPyMlirContext context) {
327           MlirAttribute attr =
328               mlirStringAttrGet(context->get(), toMlirStringRef(value));
329           return PyStringAttribute(context->getRef(), attr);
330         },
331         py::arg("value"), py::arg("context") = py::none(),
332         "Gets a uniqued string attribute");
333     c.def_static(
334         "get_typed",
335         [](PyType &type, std::string value) {
336           MlirAttribute attr =
337               mlirStringAttrTypedGet(type, toMlirStringRef(value));
338           return PyStringAttribute(type.getContext(), attr);
339         },
340 
341         "Gets a uniqued string attribute associated to a type");
342     c.def_property_readonly(
343         "value",
344         [](PyStringAttribute &self) {
345           MlirStringRef stringRef = mlirStringAttrGetValue(self);
346           return py::str(stringRef.data, stringRef.length);
347         },
348         "Returns the value of the string attribute");
349   }
350 };
351 
352 // TODO: Support construction of string elements.
353 class PyDenseElementsAttribute
354     : public PyConcreteAttribute<PyDenseElementsAttribute> {
355 public:
356   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357   static constexpr const char *pyClassName = "DenseElementsAttr";
358   using PyConcreteAttribute::PyConcreteAttribute;
359 
360   static PyDenseElementsAttribute
361   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
362                 Optional<std::vector<int64_t>> explicitShape,
363                 DefaultingPyMlirContext contextWrapper) {
364     // Request a contiguous view. In exotic cases, this will cause a copy.
365     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366     Py_buffer *view = new Py_buffer();
367     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368       delete view;
369       throw py::error_already_set();
370     }
371     py::buffer_info arrayInfo(view);
372     SmallVector<int64_t> shape;
373     if (explicitShape) {
374       shape.append(explicitShape->begin(), explicitShape->end());
375     } else {
376       shape.append(arrayInfo.shape.begin(),
377                    arrayInfo.shape.begin() + arrayInfo.ndim);
378     }
379 
380     MlirAttribute encodingAttr = mlirAttributeGetNull();
381     MlirContext context = contextWrapper->get();
382 
383     // Detect format codes that are suitable for bulk loading. This includes
384     // all byte aligned integer and floating point types up to 8 bytes.
385     // Notably, this excludes, bool (which needs to be bit-packed) and
386     // other exotics which do not have a direct representation in the buffer
387     // protocol (i.e. complex, etc).
388     Optional<MlirType> bulkLoadElementType;
389     if (explicitType) {
390       bulkLoadElementType = *explicitType;
391     } else if (arrayInfo.format == "f") {
392       // f32
393       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
394       bulkLoadElementType = mlirF32TypeGet(context);
395     } else if (arrayInfo.format == "d") {
396       // f64
397       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
398       bulkLoadElementType = mlirF64TypeGet(context);
399     } else if (arrayInfo.format == "e") {
400       // f16
401       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
402       bulkLoadElementType = mlirF16TypeGet(context);
403     } else if (isSignedIntegerFormat(arrayInfo.format)) {
404       if (arrayInfo.itemsize == 4) {
405         // i32
406         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407                                        : mlirIntegerTypeSignedGet(context, 32);
408       } else if (arrayInfo.itemsize == 8) {
409         // i64
410         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411                                        : mlirIntegerTypeSignedGet(context, 64);
412       } else if (arrayInfo.itemsize == 1) {
413         // i8
414         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
415                                        : mlirIntegerTypeSignedGet(context, 8);
416       } else if (arrayInfo.itemsize == 2) {
417         // i16
418         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
419                                        : mlirIntegerTypeSignedGet(context, 16);
420       }
421     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422       if (arrayInfo.itemsize == 4) {
423         // unsigned i32
424         bulkLoadElementType = signless
425                                   ? mlirIntegerTypeGet(context, 32)
426                                   : mlirIntegerTypeUnsignedGet(context, 32);
427       } else if (arrayInfo.itemsize == 8) {
428         // unsigned i64
429         bulkLoadElementType = signless
430                                   ? mlirIntegerTypeGet(context, 64)
431                                   : mlirIntegerTypeUnsignedGet(context, 64);
432       } else if (arrayInfo.itemsize == 1) {
433         // i8
434         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
435                                        : mlirIntegerTypeUnsignedGet(context, 8);
436       } else if (arrayInfo.itemsize == 2) {
437         // i16
438         bulkLoadElementType = signless
439                                   ? mlirIntegerTypeGet(context, 16)
440                                   : mlirIntegerTypeUnsignedGet(context, 16);
441       }
442     }
443     if (bulkLoadElementType) {
444       auto shapedType = mlirRankedTensorTypeGet(
445           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
446       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
447       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
448           shapedType, rawBufferSize, arrayInfo.ptr);
449       if (mlirAttributeIsNull(attr)) {
450         throw std::invalid_argument(
451             "DenseElementsAttr could not be constructed from the given buffer. "
452             "This may mean that the Python buffer layout does not match that "
453             "MLIR expected layout and is a bug.");
454       }
455       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
456     }
457 
458     throw std::invalid_argument(
459         std::string("unimplemented array format conversion from format: ") +
460         arrayInfo.format);
461   }
462 
463   static PyDenseElementsAttribute getSplat(PyType shapedType,
464                                            PyAttribute &elementAttr) {
465     auto contextWrapper =
466         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467     if (!mlirAttributeIsAInteger(elementAttr) &&
468         !mlirAttributeIsAFloat(elementAttr)) {
469       std::string message = "Illegal element type for DenseElementsAttr: ";
470       message.append(py::repr(py::cast(elementAttr)));
471       throw SetPyError(PyExc_ValueError, message);
472     }
473     if (!mlirTypeIsAShaped(shapedType) ||
474         !mlirShapedTypeHasStaticShape(shapedType)) {
475       std::string message =
476           "Expected a static ShapedType for the shaped_type parameter: ";
477       message.append(py::repr(py::cast(shapedType)));
478       throw SetPyError(PyExc_ValueError, message);
479     }
480     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481     MlirType attrType = mlirAttributeGetType(elementAttr);
482     if (!mlirTypeEqual(shapedElementType, attrType)) {
483       std::string message =
484           "Shaped element type and attribute type must be equal: shaped=";
485       message.append(py::repr(py::cast(shapedType)));
486       message.append(", element=");
487       message.append(py::repr(py::cast(elementAttr)));
488       throw SetPyError(PyExc_ValueError, message);
489     }
490 
491     MlirAttribute elements =
492         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494   }
495 
496   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497 
498   py::buffer_info accessBuffer() {
499     if (mlirDenseElementsAttrIsSplat(*this)) {
500       // TODO: Raise an exception.
501       // Reported as https://github.com/pybind/pybind11/issues/3336
502       return py::buffer_info();
503     }
504 
505     MlirType shapedType = mlirAttributeGetType(*this);
506     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
507     std::string format;
508 
509     if (mlirTypeIsAF32(elementType)) {
510       // f32
511       return bufferInfo<float>(shapedType);
512     } else if (mlirTypeIsAF64(elementType)) {
513       // f64
514       return bufferInfo<double>(shapedType);
515     } else if (mlirTypeIsAF16(elementType)) {
516       // f16
517       return bufferInfo<uint16_t>(shapedType, "e");
518     } else if (mlirTypeIsAInteger(elementType) &&
519                mlirIntegerTypeGetWidth(elementType) == 32) {
520       if (mlirIntegerTypeIsSignless(elementType) ||
521           mlirIntegerTypeIsSigned(elementType)) {
522         // i32
523         return bufferInfo<int32_t>(shapedType);
524       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
525         // unsigned i32
526         return bufferInfo<uint32_t>(shapedType);
527       }
528     } else if (mlirTypeIsAInteger(elementType) &&
529                mlirIntegerTypeGetWidth(elementType) == 64) {
530       if (mlirIntegerTypeIsSignless(elementType) ||
531           mlirIntegerTypeIsSigned(elementType)) {
532         // i64
533         return bufferInfo<int64_t>(shapedType);
534       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
535         // unsigned i64
536         return bufferInfo<uint64_t>(shapedType);
537       }
538     } else if (mlirTypeIsAInteger(elementType) &&
539                mlirIntegerTypeGetWidth(elementType) == 8) {
540       if (mlirIntegerTypeIsSignless(elementType) ||
541           mlirIntegerTypeIsSigned(elementType)) {
542         // i8
543         return bufferInfo<int8_t>(shapedType);
544       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
545         // unsigned i8
546         return bufferInfo<uint8_t>(shapedType);
547       }
548     } else if (mlirTypeIsAInteger(elementType) &&
549                mlirIntegerTypeGetWidth(elementType) == 16) {
550       if (mlirIntegerTypeIsSignless(elementType) ||
551           mlirIntegerTypeIsSigned(elementType)) {
552         // i16
553         return bufferInfo<int16_t>(shapedType);
554       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
555         // unsigned i16
556         return bufferInfo<uint16_t>(shapedType);
557       }
558     }
559 
560     // TODO: Currently crashes the program. Just returning an empty buffer
561     // for now.
562     // Reported as https://github.com/pybind/pybind11/issues/3336
563     // throw std::invalid_argument(
564     //     "unsupported data type for conversion to Python buffer");
565     return py::buffer_info();
566   }
567 
568   static void bindDerived(ClassTy &c) {
569     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
570         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
571                     py::arg("array"), py::arg("signless") = true,
572                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
573                     py::arg("context") = py::none(),
574                     kDenseElementsAttrGetDocstring)
575         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
576                     py::arg("shaped_type"), py::arg("element_attr"),
577                     "Gets a DenseElementsAttr where all values are the same")
578         .def_property_readonly("is_splat",
579                                [](PyDenseElementsAttribute &self) -> bool {
580                                  return mlirDenseElementsAttrIsSplat(self);
581                                })
582         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
583   }
584 
585 private:
586   static bool isUnsignedIntegerFormat(const std::string &format) {
587     if (format.empty())
588       return false;
589     char code = format[0];
590     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
591            code == 'Q';
592   }
593 
594   static bool isSignedIntegerFormat(const std::string &format) {
595     if (format.empty())
596       return false;
597     char code = format[0];
598     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
599            code == 'q';
600   }
601 
602   template <typename Type>
603   py::buffer_info bufferInfo(MlirType shapedType,
604                              const char *explicitFormat = nullptr) {
605     intptr_t rank = mlirShapedTypeGetRank(shapedType);
606     // Prepare the data for the buffer_info.
607     // Buffer is configured for read-only access below.
608     Type *data = static_cast<Type *>(
609         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
610     // Prepare the shape for the buffer_info.
611     SmallVector<intptr_t, 4> shape;
612     for (intptr_t i = 0; i < rank; ++i)
613       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
614     // Prepare the strides for the buffer_info.
615     SmallVector<intptr_t, 4> strides;
616     intptr_t strideFactor = 1;
617     for (intptr_t i = 1; i < rank; ++i) {
618       strideFactor = 1;
619       for (intptr_t j = i; j < rank; ++j) {
620         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
621       }
622       strides.push_back(sizeof(Type) * strideFactor);
623     }
624     strides.push_back(sizeof(Type));
625     std::string format;
626     if (explicitFormat) {
627       format = explicitFormat;
628     } else {
629       format = py::format_descriptor<Type>::format();
630     }
631     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
632                            /*readonly=*/true);
633   }
634 }; // namespace
635 
636 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
637 /// (and boolean) values. Supports element access.
638 class PyDenseIntElementsAttribute
639     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
640                                  PyDenseElementsAttribute> {
641 public:
642   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
643   static constexpr const char *pyClassName = "DenseIntElementsAttr";
644   using PyConcreteAttribute::PyConcreteAttribute;
645 
646   /// Returns the element at the given linear position. Asserts if the index is
647   /// out of range.
648   py::int_ dunderGetItem(intptr_t pos) {
649     if (pos < 0 || pos >= dunderLen()) {
650       throw SetPyError(PyExc_IndexError,
651                        "attempt to access out of bounds element");
652     }
653 
654     MlirType type = mlirAttributeGetType(*this);
655     type = mlirShapedTypeGetElementType(type);
656     assert(mlirTypeIsAInteger(type) &&
657            "expected integer element type in dense int elements attribute");
658     // Dispatch element extraction to an appropriate C function based on the
659     // elemental type of the attribute. py::int_ is implicitly constructible
660     // from any C++ integral type and handles bitwidth correctly.
661     // TODO: consider caching the type properties in the constructor to avoid
662     // querying them on each element access.
663     unsigned width = mlirIntegerTypeGetWidth(type);
664     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
665     if (isUnsigned) {
666       if (width == 1) {
667         return mlirDenseElementsAttrGetBoolValue(*this, pos);
668       }
669       if (width == 32) {
670         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
671       }
672       if (width == 64) {
673         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
674       }
675     } else {
676       if (width == 1) {
677         return mlirDenseElementsAttrGetBoolValue(*this, pos);
678       }
679       if (width == 32) {
680         return mlirDenseElementsAttrGetInt32Value(*this, pos);
681       }
682       if (width == 64) {
683         return mlirDenseElementsAttrGetInt64Value(*this, pos);
684       }
685     }
686     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
687   }
688 
689   static void bindDerived(ClassTy &c) {
690     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
691   }
692 };
693 
694 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
695 public:
696   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
697   static constexpr const char *pyClassName = "DictAttr";
698   using PyConcreteAttribute::PyConcreteAttribute;
699 
700   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
701 
702   static void bindDerived(ClassTy &c) {
703     c.def("__len__", &PyDictAttribute::dunderLen);
704     c.def_static(
705         "get",
706         [](py::dict attributes, DefaultingPyMlirContext context) {
707           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
708           mlirNamedAttributes.reserve(attributes.size());
709           for (auto &it : attributes) {
710             auto &mlir_attr = it.second.cast<PyAttribute &>();
711             auto name = it.first.cast<std::string>();
712             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
713                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
714                                   toMlirStringRef(name)),
715                 mlir_attr));
716           }
717           MlirAttribute attr =
718               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
719                                     mlirNamedAttributes.data());
720           return PyDictAttribute(context->getRef(), attr);
721         },
722         py::arg("value") = py::dict(), py::arg("context") = py::none(),
723         "Gets an uniqued dict attribute");
724     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
725       MlirAttribute attr =
726           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
727       if (mlirAttributeIsNull(attr)) {
728         throw SetPyError(PyExc_KeyError,
729                          "attempt to access a non-existent attribute");
730       }
731       return PyAttribute(self.getContext(), attr);
732     });
733     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
734       if (index < 0 || index >= self.dunderLen()) {
735         throw SetPyError(PyExc_IndexError,
736                          "attempt to access out of bounds attribute");
737       }
738       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
739       return PyNamedAttribute(
740           namedAttr.attribute,
741           std::string(mlirIdentifierStr(namedAttr.name).data));
742     });
743   }
744 };
745 
746 /// Refinement of PyDenseElementsAttribute for attributes containing
747 /// floating-point values. Supports element access.
748 class PyDenseFPElementsAttribute
749     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
750                                  PyDenseElementsAttribute> {
751 public:
752   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
753   static constexpr const char *pyClassName = "DenseFPElementsAttr";
754   using PyConcreteAttribute::PyConcreteAttribute;
755 
756   py::float_ dunderGetItem(intptr_t pos) {
757     if (pos < 0 || pos >= dunderLen()) {
758       throw SetPyError(PyExc_IndexError,
759                        "attempt to access out of bounds element");
760     }
761 
762     MlirType type = mlirAttributeGetType(*this);
763     type = mlirShapedTypeGetElementType(type);
764     // Dispatch element extraction to an appropriate C function based on the
765     // elemental type of the attribute. py::float_ is implicitly constructible
766     // from float and double.
767     // TODO: consider caching the type properties in the constructor to avoid
768     // querying them on each element access.
769     if (mlirTypeIsAF32(type)) {
770       return mlirDenseElementsAttrGetFloatValue(*this, pos);
771     }
772     if (mlirTypeIsAF64(type)) {
773       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
774     }
775     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
776   }
777 
778   static void bindDerived(ClassTy &c) {
779     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
780   }
781 };
782 
783 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
784 public:
785   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
786   static constexpr const char *pyClassName = "TypeAttr";
787   using PyConcreteAttribute::PyConcreteAttribute;
788 
789   static void bindDerived(ClassTy &c) {
790     c.def_static(
791         "get",
792         [](PyType value, DefaultingPyMlirContext context) {
793           MlirAttribute attr = mlirTypeAttrGet(value.get());
794           return PyTypeAttribute(context->getRef(), attr);
795         },
796         py::arg("value"), py::arg("context") = py::none(),
797         "Gets a uniqued Type attribute");
798     c.def_property_readonly("value", [](PyTypeAttribute &self) {
799       return PyType(self.getContext()->getRef(),
800                     mlirTypeAttrGetValue(self.get()));
801     });
802   }
803 };
804 
805 /// Unit Attribute subclass. Unit attributes don't have values.
806 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
807 public:
808   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
809   static constexpr const char *pyClassName = "UnitAttr";
810   using PyConcreteAttribute::PyConcreteAttribute;
811 
812   static void bindDerived(ClassTy &c) {
813     c.def_static(
814         "get",
815         [](DefaultingPyMlirContext context) {
816           return PyUnitAttribute(context->getRef(),
817                                  mlirUnitAttrGet(context->get()));
818         },
819         py::arg("context") = py::none(), "Create a Unit attribute.");
820   }
821 };
822 
823 } // namespace
824 
/// Registers every builtin attribute binding class defined in this file on
/// the Python module `m`.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  // The nested iterator type used by ArrayAttr iteration is bound separately.
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // PyDenseElementsAttribute is the C++ base of the FP/Int dense element
  // bindings below; it is bound first, presumably because pybind11 needs a
  // base class registered before its subclasses — keep this order.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
841