1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "IRModule.h"
12 
13 #include "PybindUtils.h"
14 
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 
18 namespace py = pybind11;
19 using namespace mlir;
20 using namespace mlir::python;
21 
22 using llvm::Optional;
23 using llvm::SmallVector;
24 using llvm::Twine;
25 
26 //------------------------------------------------------------------------------
27 // Docstrings (trivial, non-duplicated docstrings are included inline).
28 //------------------------------------------------------------------------------
29 
// Python docstring attached to the `DenseElementsAttr.get` binding. It is kept
// out of line because it is too long to read comfortably inside the binding
// call.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
71 
72 namespace {
73 
toMlirStringRef(const std::string & s)74 static MlirStringRef toMlirStringRef(const std::string &s) {
75   return mlirStringRefCreate(s.data(), s.size());
76 }
77 
78 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79 public:
80   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81   static constexpr const char *pyClassName = "AffineMapAttr";
82   using PyConcreteAttribute::PyConcreteAttribute;
83 
bindDerived(ClassTy & c)84   static void bindDerived(ClassTy &c) {
85     c.def_static(
86         "get",
87         [](PyAffineMap &affineMap) {
88           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
89           return PyAffineMapAttribute(affineMap.getContext(), attr);
90         },
91         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
92   }
93 };
94 
95 template <typename T>
pyTryCast(py::handle object)96 static T pyTryCast(py::handle object) {
97   try {
98     return object.cast<T>();
99   } catch (py::cast_error &err) {
100     std::string msg =
101         std::string(
102             "Invalid attribute when attempting to create an ArrayAttribute (") +
103         err.what() + ")";
104     throw py::cast_error(msg);
105   } catch (py::reference_cast_error &err) {
106     std::string msg = std::string("Invalid attribute (None?) when attempting "
107                                   "to create an ArrayAttribute (") +
108                       err.what() + ")";
109     throw py::cast_error(msg);
110   }
111 }
112 
113 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
114 public:
115   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
116   static constexpr const char *pyClassName = "ArrayAttr";
117   using PyConcreteAttribute::PyConcreteAttribute;
118 
119   class PyArrayAttributeIterator {
120   public:
PyArrayAttributeIterator(PyAttribute attr)121     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
122 
dunderIter()123     PyArrayAttributeIterator &dunderIter() { return *this; }
124 
dunderNext()125     PyAttribute dunderNext() {
126       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
127         throw py::stop_iteration();
128       }
129       return PyAttribute(attr.getContext(),
130                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
131     }
132 
bind(py::module & m)133     static void bind(py::module &m) {
134       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
135                                            py::module_local())
136           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
137           .def("__next__", &PyArrayAttributeIterator::dunderNext);
138     }
139 
140   private:
141     PyAttribute attr;
142     int nextIndex = 0;
143   };
144 
getItem(intptr_t i)145   PyAttribute getItem(intptr_t i) {
146     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
147   }
148 
bindDerived(ClassTy & c)149   static void bindDerived(ClassTy &c) {
150     c.def_static(
151         "get",
152         [](py::list attributes, DefaultingPyMlirContext context) {
153           SmallVector<MlirAttribute> mlirAttributes;
154           mlirAttributes.reserve(py::len(attributes));
155           for (auto attribute : attributes) {
156             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
157           }
158           MlirAttribute attr = mlirArrayAttrGet(
159               context->get(), mlirAttributes.size(), mlirAttributes.data());
160           return PyArrayAttribute(context->getRef(), attr);
161         },
162         py::arg("attributes"), py::arg("context") = py::none(),
163         "Gets a uniqued Array attribute");
164     c.def("__getitem__",
165           [](PyArrayAttribute &arr, intptr_t i) {
166             if (i >= mlirArrayAttrGetNumElements(arr))
167               throw py::index_error("ArrayAttribute index out of range");
168             return arr.getItem(i);
169           })
170         .def("__len__",
171              [](const PyArrayAttribute &arr) {
172                return mlirArrayAttrGetNumElements(arr);
173              })
174         .def("__iter__", [](const PyArrayAttribute &arr) {
175           return PyArrayAttributeIterator(arr);
176         });
177     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
178       std::vector<MlirAttribute> attributes;
179       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
180       attributes.reserve(numOldElements + py::len(extras));
181       for (intptr_t i = 0; i < numOldElements; ++i)
182         attributes.push_back(arr.getItem(i));
183       for (py::handle attr : extras)
184         attributes.push_back(pyTryCast<PyAttribute>(attr));
185       MlirAttribute arrayAttr = mlirArrayAttrGet(
186           arr.getContext()->get(), attributes.size(), attributes.data());
187       return PyArrayAttribute(arr.getContext(), arrayAttr);
188     });
189   }
190 };
191 
192 /// Float Point Attribute subclass - FloatAttr.
193 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
194 public:
195   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
196   static constexpr const char *pyClassName = "FloatAttr";
197   using PyConcreteAttribute::PyConcreteAttribute;
198 
bindDerived(ClassTy & c)199   static void bindDerived(ClassTy &c) {
200     c.def_static(
201         "get",
202         [](PyType &type, double value, DefaultingPyLocation loc) {
203           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
204           // TODO: Rework error reporting once diagnostic engine is exposed
205           // in C API.
206           if (mlirAttributeIsNull(attr)) {
207             throw SetPyError(PyExc_ValueError,
208                              Twine("invalid '") +
209                                  py::repr(py::cast(type)).cast<std::string>() +
210                                  "' and expected floating point type.");
211           }
212           return PyFloatAttribute(type.getContext(), attr);
213         },
214         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
215         "Gets an uniqued float point attribute associated to a type");
216     c.def_static(
217         "get_f32",
218         [](double value, DefaultingPyMlirContext context) {
219           MlirAttribute attr = mlirFloatAttrDoubleGet(
220               context->get(), mlirF32TypeGet(context->get()), value);
221           return PyFloatAttribute(context->getRef(), attr);
222         },
223         py::arg("value"), py::arg("context") = py::none(),
224         "Gets an uniqued float point attribute associated to a f32 type");
225     c.def_static(
226         "get_f64",
227         [](double value, DefaultingPyMlirContext context) {
228           MlirAttribute attr = mlirFloatAttrDoubleGet(
229               context->get(), mlirF64TypeGet(context->get()), value);
230           return PyFloatAttribute(context->getRef(), attr);
231         },
232         py::arg("value"), py::arg("context") = py::none(),
233         "Gets an uniqued float point attribute associated to a f64 type");
234     c.def_property_readonly(
235         "value",
236         [](PyFloatAttribute &self) {
237           return mlirFloatAttrGetValueDouble(self);
238         },
239         "Returns the value of the float point attribute");
240   }
241 };
242 
243 /// Integer Attribute subclass - IntegerAttr.
244 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
245 public:
246   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
247   static constexpr const char *pyClassName = "IntegerAttr";
248   using PyConcreteAttribute::PyConcreteAttribute;
249 
bindDerived(ClassTy & c)250   static void bindDerived(ClassTy &c) {
251     c.def_static(
252         "get",
253         [](PyType &type, int64_t value) {
254           MlirAttribute attr = mlirIntegerAttrGet(type, value);
255           return PyIntegerAttribute(type.getContext(), attr);
256         },
257         py::arg("type"), py::arg("value"),
258         "Gets an uniqued integer attribute associated to a type");
259     c.def_property_readonly(
260         "value",
261         [](PyIntegerAttribute &self) -> py::int_ {
262           MlirType type = mlirAttributeGetType(self);
263           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
264             return mlirIntegerAttrGetValueInt(self);
265           if (mlirIntegerTypeIsSigned(type))
266             return mlirIntegerAttrGetValueSInt(self);
267           return mlirIntegerAttrGetValueUInt(self);
268         },
269         "Returns the value of the integer attribute");
270   }
271 };
272 
273 /// Bool Attribute subclass - BoolAttr.
274 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
275 public:
276   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
277   static constexpr const char *pyClassName = "BoolAttr";
278   using PyConcreteAttribute::PyConcreteAttribute;
279 
bindDerived(ClassTy & c)280   static void bindDerived(ClassTy &c) {
281     c.def_static(
282         "get",
283         [](bool value, DefaultingPyMlirContext context) {
284           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
285           return PyBoolAttribute(context->getRef(), attr);
286         },
287         py::arg("value"), py::arg("context") = py::none(),
288         "Gets an uniqued bool attribute");
289     c.def_property_readonly(
290         "value",
291         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
292         "Returns the value of the bool attribute");
293   }
294 };
295 
296 class PyFlatSymbolRefAttribute
297     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
298 public:
299   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
300   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
301   using PyConcreteAttribute::PyConcreteAttribute;
302 
bindDerived(ClassTy & c)303   static void bindDerived(ClassTy &c) {
304     c.def_static(
305         "get",
306         [](std::string value, DefaultingPyMlirContext context) {
307           MlirAttribute attr =
308               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
309           return PyFlatSymbolRefAttribute(context->getRef(), attr);
310         },
311         py::arg("value"), py::arg("context") = py::none(),
312         "Gets a uniqued FlatSymbolRef attribute");
313     c.def_property_readonly(
314         "value",
315         [](PyFlatSymbolRefAttribute &self) {
316           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
317           return py::str(stringRef.data, stringRef.length);
318         },
319         "Returns the value of the FlatSymbolRef attribute as a string");
320   }
321 };
322 
323 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
324 public:
325   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
326   static constexpr const char *pyClassName = "OpaqueAttr";
327   using PyConcreteAttribute::PyConcreteAttribute;
328 
bindDerived(ClassTy & c)329   static void bindDerived(ClassTy &c) {
330     c.def_static(
331         "get",
332         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
333            DefaultingPyMlirContext context) {
334           const py::buffer_info bufferInfo = buffer.request();
335           intptr_t bufferSize = bufferInfo.size;
336           MlirAttribute attr = mlirOpaqueAttrGet(
337               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
338               static_cast<char *>(bufferInfo.ptr), type);
339           return PyOpaqueAttribute(context->getRef(), attr);
340         },
341         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
342         py::arg("context") = py::none(), "Gets an Opaque attribute.");
343     c.def_property_readonly(
344         "dialect_namespace",
345         [](PyOpaqueAttribute &self) {
346           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
347           return py::str(stringRef.data, stringRef.length);
348         },
349         "Returns the dialect namespace for the Opaque attribute as a string");
350     c.def_property_readonly(
351         "data",
352         [](PyOpaqueAttribute &self) {
353           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
354           return py::str(stringRef.data, stringRef.length);
355         },
356         "Returns the data for the Opaqued attributes as a string");
357   }
358 };
359 
360 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
361 public:
362   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
363   static constexpr const char *pyClassName = "StringAttr";
364   using PyConcreteAttribute::PyConcreteAttribute;
365 
bindDerived(ClassTy & c)366   static void bindDerived(ClassTy &c) {
367     c.def_static(
368         "get",
369         [](std::string value, DefaultingPyMlirContext context) {
370           MlirAttribute attr =
371               mlirStringAttrGet(context->get(), toMlirStringRef(value));
372           return PyStringAttribute(context->getRef(), attr);
373         },
374         py::arg("value"), py::arg("context") = py::none(),
375         "Gets a uniqued string attribute");
376     c.def_static(
377         "get_typed",
378         [](PyType &type, std::string value) {
379           MlirAttribute attr =
380               mlirStringAttrTypedGet(type, toMlirStringRef(value));
381           return PyStringAttribute(type.getContext(), attr);
382         },
383         py::arg("type"), py::arg("value"),
384         "Gets a uniqued string attribute associated to a type");
385     c.def_property_readonly(
386         "value",
387         [](PyStringAttribute &self) {
388           MlirStringRef stringRef = mlirStringAttrGetValue(self);
389           return py::str(stringRef.data, stringRef.length);
390         },
391         "Returns the value of the string attribute");
392   }
393 };
394 
395 // TODO: Support construction of string elements.
396 class PyDenseElementsAttribute
397     : public PyConcreteAttribute<PyDenseElementsAttribute> {
398 public:
399   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
400   static constexpr const char *pyClassName = "DenseElementsAttr";
401   using PyConcreteAttribute::PyConcreteAttribute;
402 
403   static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,Optional<PyType> explicitType,Optional<std::vector<int64_t>> explicitShape,DefaultingPyMlirContext contextWrapper)404   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
405                 Optional<std::vector<int64_t>> explicitShape,
406                 DefaultingPyMlirContext contextWrapper) {
407     // Request a contiguous view. In exotic cases, this will cause a copy.
408     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
409     Py_buffer *view = new Py_buffer();
410     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
411       delete view;
412       throw py::error_already_set();
413     }
414     py::buffer_info arrayInfo(view);
415     SmallVector<int64_t> shape;
416     if (explicitShape) {
417       shape.append(explicitShape->begin(), explicitShape->end());
418     } else {
419       shape.append(arrayInfo.shape.begin(),
420                    arrayInfo.shape.begin() + arrayInfo.ndim);
421     }
422 
423     MlirAttribute encodingAttr = mlirAttributeGetNull();
424     MlirContext context = contextWrapper->get();
425 
426     // Detect format codes that are suitable for bulk loading. This includes
427     // all byte aligned integer and floating point types up to 8 bytes.
428     // Notably, this excludes, bool (which needs to be bit-packed) and
429     // other exotics which do not have a direct representation in the buffer
430     // protocol (i.e. complex, etc).
431     Optional<MlirType> bulkLoadElementType;
432     if (explicitType) {
433       bulkLoadElementType = *explicitType;
434     } else if (arrayInfo.format == "f") {
435       // f32
436       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
437       bulkLoadElementType = mlirF32TypeGet(context);
438     } else if (arrayInfo.format == "d") {
439       // f64
440       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
441       bulkLoadElementType = mlirF64TypeGet(context);
442     } else if (arrayInfo.format == "e") {
443       // f16
444       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
445       bulkLoadElementType = mlirF16TypeGet(context);
446     } else if (isSignedIntegerFormat(arrayInfo.format)) {
447       if (arrayInfo.itemsize == 4) {
448         // i32
449         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
450                                        : mlirIntegerTypeSignedGet(context, 32);
451       } else if (arrayInfo.itemsize == 8) {
452         // i64
453         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
454                                        : mlirIntegerTypeSignedGet(context, 64);
455       } else if (arrayInfo.itemsize == 1) {
456         // i8
457         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
458                                        : mlirIntegerTypeSignedGet(context, 8);
459       } else if (arrayInfo.itemsize == 2) {
460         // i16
461         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
462                                        : mlirIntegerTypeSignedGet(context, 16);
463       }
464     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
465       if (arrayInfo.itemsize == 4) {
466         // unsigned i32
467         bulkLoadElementType = signless
468                                   ? mlirIntegerTypeGet(context, 32)
469                                   : mlirIntegerTypeUnsignedGet(context, 32);
470       } else if (arrayInfo.itemsize == 8) {
471         // unsigned i64
472         bulkLoadElementType = signless
473                                   ? mlirIntegerTypeGet(context, 64)
474                                   : mlirIntegerTypeUnsignedGet(context, 64);
475       } else if (arrayInfo.itemsize == 1) {
476         // i8
477         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
478                                        : mlirIntegerTypeUnsignedGet(context, 8);
479       } else if (arrayInfo.itemsize == 2) {
480         // i16
481         bulkLoadElementType = signless
482                                   ? mlirIntegerTypeGet(context, 16)
483                                   : mlirIntegerTypeUnsignedGet(context, 16);
484       }
485     }
486     if (bulkLoadElementType) {
487       auto shapedType = mlirRankedTensorTypeGet(
488           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
489       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
490       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
491           shapedType, rawBufferSize, arrayInfo.ptr);
492       if (mlirAttributeIsNull(attr)) {
493         throw std::invalid_argument(
494             "DenseElementsAttr could not be constructed from the given buffer. "
495             "This may mean that the Python buffer layout does not match that "
496             "MLIR expected layout and is a bug.");
497       }
498       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
499     }
500 
501     throw std::invalid_argument(
502         std::string("unimplemented array format conversion from format: ") +
503         arrayInfo.format);
504   }
505 
getSplat(const PyType & shapedType,PyAttribute & elementAttr)506   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
507                                            PyAttribute &elementAttr) {
508     auto contextWrapper =
509         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
510     if (!mlirAttributeIsAInteger(elementAttr) &&
511         !mlirAttributeIsAFloat(elementAttr)) {
512       std::string message = "Illegal element type for DenseElementsAttr: ";
513       message.append(py::repr(py::cast(elementAttr)));
514       throw SetPyError(PyExc_ValueError, message);
515     }
516     if (!mlirTypeIsAShaped(shapedType) ||
517         !mlirShapedTypeHasStaticShape(shapedType)) {
518       std::string message =
519           "Expected a static ShapedType for the shaped_type parameter: ";
520       message.append(py::repr(py::cast(shapedType)));
521       throw SetPyError(PyExc_ValueError, message);
522     }
523     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
524     MlirType attrType = mlirAttributeGetType(elementAttr);
525     if (!mlirTypeEqual(shapedElementType, attrType)) {
526       std::string message =
527           "Shaped element type and attribute type must be equal: shaped=";
528       message.append(py::repr(py::cast(shapedType)));
529       message.append(", element=");
530       message.append(py::repr(py::cast(elementAttr)));
531       throw SetPyError(PyExc_ValueError, message);
532     }
533 
534     MlirAttribute elements =
535         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
536     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
537   }
538 
dunderLen()539   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
540 
accessBuffer()541   py::buffer_info accessBuffer() {
542     if (mlirDenseElementsAttrIsSplat(*this)) {
543       // TODO: Currently crashes the program.
544       // Reported as https://github.com/pybind/pybind11/issues/3336
545       throw std::invalid_argument(
546           "unsupported data type for conversion to Python buffer");
547     }
548 
549     MlirType shapedType = mlirAttributeGetType(*this);
550     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
551     std::string format;
552 
553     if (mlirTypeIsAF32(elementType)) {
554       // f32
555       return bufferInfo<float>(shapedType);
556     }
557     if (mlirTypeIsAF64(elementType)) {
558       // f64
559       return bufferInfo<double>(shapedType);
560     }
561     if (mlirTypeIsAF16(elementType)) {
562       // f16
563       return bufferInfo<uint16_t>(shapedType, "e");
564     }
565     if (mlirTypeIsAInteger(elementType) &&
566         mlirIntegerTypeGetWidth(elementType) == 32) {
567       if (mlirIntegerTypeIsSignless(elementType) ||
568           mlirIntegerTypeIsSigned(elementType)) {
569         // i32
570         return bufferInfo<int32_t>(shapedType);
571       }
572       if (mlirIntegerTypeIsUnsigned(elementType)) {
573         // unsigned i32
574         return bufferInfo<uint32_t>(shapedType);
575       }
576     } else if (mlirTypeIsAInteger(elementType) &&
577                mlirIntegerTypeGetWidth(elementType) == 64) {
578       if (mlirIntegerTypeIsSignless(elementType) ||
579           mlirIntegerTypeIsSigned(elementType)) {
580         // i64
581         return bufferInfo<int64_t>(shapedType);
582       }
583       if (mlirIntegerTypeIsUnsigned(elementType)) {
584         // unsigned i64
585         return bufferInfo<uint64_t>(shapedType);
586       }
587     } else if (mlirTypeIsAInteger(elementType) &&
588                mlirIntegerTypeGetWidth(elementType) == 8) {
589       if (mlirIntegerTypeIsSignless(elementType) ||
590           mlirIntegerTypeIsSigned(elementType)) {
591         // i8
592         return bufferInfo<int8_t>(shapedType);
593       }
594       if (mlirIntegerTypeIsUnsigned(elementType)) {
595         // unsigned i8
596         return bufferInfo<uint8_t>(shapedType);
597       }
598     } else if (mlirTypeIsAInteger(elementType) &&
599                mlirIntegerTypeGetWidth(elementType) == 16) {
600       if (mlirIntegerTypeIsSignless(elementType) ||
601           mlirIntegerTypeIsSigned(elementType)) {
602         // i16
603         return bufferInfo<int16_t>(shapedType);
604       }
605       if (mlirIntegerTypeIsUnsigned(elementType)) {
606         // unsigned i16
607         return bufferInfo<uint16_t>(shapedType);
608       }
609     }
610 
611     // TODO: Currently crashes the program.
612     // Reported as https://github.com/pybind/pybind11/issues/3336
613     throw std::invalid_argument(
614         "unsupported data type for conversion to Python buffer");
615   }
616 
bindDerived(ClassTy & c)617   static void bindDerived(ClassTy &c) {
618     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
619         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
620                     py::arg("array"), py::arg("signless") = true,
621                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
622                     py::arg("context") = py::none(),
623                     kDenseElementsAttrGetDocstring)
624         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
625                     py::arg("shaped_type"), py::arg("element_attr"),
626                     "Gets a DenseElementsAttr where all values are the same")
627         .def_property_readonly("is_splat",
628                                [](PyDenseElementsAttribute &self) -> bool {
629                                  return mlirDenseElementsAttrIsSplat(self);
630                                })
631         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
632   }
633 
634 private:
isUnsignedIntegerFormat(const std::string & format)635   static bool isUnsignedIntegerFormat(const std::string &format) {
636     if (format.empty())
637       return false;
638     char code = format[0];
639     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
640            code == 'Q';
641   }
642 
isSignedIntegerFormat(const std::string & format)643   static bool isSignedIntegerFormat(const std::string &format) {
644     if (format.empty())
645       return false;
646     char code = format[0];
647     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
648            code == 'q';
649   }
650 
651   template <typename Type>
bufferInfo(MlirType shapedType,const char * explicitFormat=nullptr)652   py::buffer_info bufferInfo(MlirType shapedType,
653                              const char *explicitFormat = nullptr) {
654     intptr_t rank = mlirShapedTypeGetRank(shapedType);
655     // Prepare the data for the buffer_info.
656     // Buffer is configured for read-only access below.
657     Type *data = static_cast<Type *>(
658         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
659     // Prepare the shape for the buffer_info.
660     SmallVector<intptr_t, 4> shape;
661     for (intptr_t i = 0; i < rank; ++i)
662       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
663     // Prepare the strides for the buffer_info.
664     SmallVector<intptr_t, 4> strides;
665     intptr_t strideFactor = 1;
666     for (intptr_t i = 1; i < rank; ++i) {
667       strideFactor = 1;
668       for (intptr_t j = i; j < rank; ++j) {
669         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
670       }
671       strides.push_back(sizeof(Type) * strideFactor);
672     }
673     strides.push_back(sizeof(Type));
674     std::string format;
675     if (explicitFormat) {
676       format = explicitFormat;
677     } else {
678       format = py::format_descriptor<Type>::format();
679     }
680     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
681                            /*readonly=*/true);
682   }
683 }; // namespace
684 
685 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
686 /// (and boolean) values. Supports element access.
687 class PyDenseIntElementsAttribute
688     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
689                                  PyDenseElementsAttribute> {
690 public:
691   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
692   static constexpr const char *pyClassName = "DenseIntElementsAttr";
693   using PyConcreteAttribute::PyConcreteAttribute;
694 
695   /// Returns the element at the given linear position. Asserts if the index is
696   /// out of range.
dunderGetItem(intptr_t pos)697   py::int_ dunderGetItem(intptr_t pos) {
698     if (pos < 0 || pos >= dunderLen()) {
699       throw SetPyError(PyExc_IndexError,
700                        "attempt to access out of bounds element");
701     }
702 
703     MlirType type = mlirAttributeGetType(*this);
704     type = mlirShapedTypeGetElementType(type);
705     assert(mlirTypeIsAInteger(type) &&
706            "expected integer element type in dense int elements attribute");
707     // Dispatch element extraction to an appropriate C function based on the
708     // elemental type of the attribute. py::int_ is implicitly constructible
709     // from any C++ integral type and handles bitwidth correctly.
710     // TODO: consider caching the type properties in the constructor to avoid
711     // querying them on each element access.
712     unsigned width = mlirIntegerTypeGetWidth(type);
713     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
714     if (isUnsigned) {
715       if (width == 1) {
716         return mlirDenseElementsAttrGetBoolValue(*this, pos);
717       }
718       if (width == 8) {
719         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
720       }
721       if (width == 16) {
722         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
723       }
724       if (width == 32) {
725         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
726       }
727       if (width == 64) {
728         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
729       }
730     } else {
731       if (width == 1) {
732         return mlirDenseElementsAttrGetBoolValue(*this, pos);
733       }
734       if (width == 8) {
735         return mlirDenseElementsAttrGetInt8Value(*this, pos);
736       }
737       if (width == 16) {
738         return mlirDenseElementsAttrGetInt16Value(*this, pos);
739       }
740       if (width == 32) {
741         return mlirDenseElementsAttrGetInt32Value(*this, pos);
742       }
743       if (width == 64) {
744         return mlirDenseElementsAttrGetInt64Value(*this, pos);
745       }
746     }
747     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
748   }
749 
bindDerived(ClassTy & c)750   static void bindDerived(ClassTy &c) {
751     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
752   }
753 };
754 
755 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
756 public:
757   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
758   static constexpr const char *pyClassName = "DictAttr";
759   using PyConcreteAttribute::PyConcreteAttribute;
760 
dunderLen()761   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
762 
dunderContains(const std::string & name)763   bool dunderContains(const std::string &name) {
764     return !mlirAttributeIsNull(
765         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
766   }
767 
bindDerived(ClassTy & c)768   static void bindDerived(ClassTy &c) {
769     c.def("__contains__", &PyDictAttribute::dunderContains);
770     c.def("__len__", &PyDictAttribute::dunderLen);
771     c.def_static(
772         "get",
773         [](py::dict attributes, DefaultingPyMlirContext context) {
774           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
775           mlirNamedAttributes.reserve(attributes.size());
776           for (auto &it : attributes) {
777             auto &mlirAttr = it.second.cast<PyAttribute &>();
778             auto name = it.first.cast<std::string>();
779             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
780                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
781                                   toMlirStringRef(name)),
782                 mlirAttr));
783           }
784           MlirAttribute attr =
785               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
786                                     mlirNamedAttributes.data());
787           return PyDictAttribute(context->getRef(), attr);
788         },
789         py::arg("value") = py::dict(), py::arg("context") = py::none(),
790         "Gets an uniqued dict attribute");
791     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
792       MlirAttribute attr =
793           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
794       if (mlirAttributeIsNull(attr)) {
795         throw SetPyError(PyExc_KeyError,
796                          "attempt to access a non-existent attribute");
797       }
798       return PyAttribute(self.getContext(), attr);
799     });
800     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
801       if (index < 0 || index >= self.dunderLen()) {
802         throw SetPyError(PyExc_IndexError,
803                          "attempt to access out of bounds attribute");
804       }
805       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
806       return PyNamedAttribute(
807           namedAttr.attribute,
808           std::string(mlirIdentifierStr(namedAttr.name).data));
809     });
810   }
811 };
812 
813 /// Refinement of PyDenseElementsAttribute for attributes containing
814 /// floating-point values. Supports element access.
815 class PyDenseFPElementsAttribute
816     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
817                                  PyDenseElementsAttribute> {
818 public:
819   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
820   static constexpr const char *pyClassName = "DenseFPElementsAttr";
821   using PyConcreteAttribute::PyConcreteAttribute;
822 
dunderGetItem(intptr_t pos)823   py::float_ dunderGetItem(intptr_t pos) {
824     if (pos < 0 || pos >= dunderLen()) {
825       throw SetPyError(PyExc_IndexError,
826                        "attempt to access out of bounds element");
827     }
828 
829     MlirType type = mlirAttributeGetType(*this);
830     type = mlirShapedTypeGetElementType(type);
831     // Dispatch element extraction to an appropriate C function based on the
832     // elemental type of the attribute. py::float_ is implicitly constructible
833     // from float and double.
834     // TODO: consider caching the type properties in the constructor to avoid
835     // querying them on each element access.
836     if (mlirTypeIsAF32(type)) {
837       return mlirDenseElementsAttrGetFloatValue(*this, pos);
838     }
839     if (mlirTypeIsAF64(type)) {
840       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
841     }
842     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
843   }
844 
bindDerived(ClassTy & c)845   static void bindDerived(ClassTy &c) {
846     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
847   }
848 };
849 
850 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
851 public:
852   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
853   static constexpr const char *pyClassName = "TypeAttr";
854   using PyConcreteAttribute::PyConcreteAttribute;
855 
bindDerived(ClassTy & c)856   static void bindDerived(ClassTy &c) {
857     c.def_static(
858         "get",
859         [](PyType value, DefaultingPyMlirContext context) {
860           MlirAttribute attr = mlirTypeAttrGet(value.get());
861           return PyTypeAttribute(context->getRef(), attr);
862         },
863         py::arg("value"), py::arg("context") = py::none(),
864         "Gets a uniqued Type attribute");
865     c.def_property_readonly("value", [](PyTypeAttribute &self) {
866       return PyType(self.getContext()->getRef(),
867                     mlirTypeAttrGetValue(self.get()));
868     });
869   }
870 };
871 
872 /// Unit Attribute subclass. Unit attributes don't have values.
873 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
874 public:
875   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
876   static constexpr const char *pyClassName = "UnitAttr";
877   using PyConcreteAttribute::PyConcreteAttribute;
878 
bindDerived(ClassTy & c)879   static void bindDerived(ClassTy &c) {
880     c.def_static(
881         "get",
882         [](DefaultingPyMlirContext context) {
883           return PyUnitAttribute(context->getRef(),
884                                  mlirUnitAttrGet(context->get()));
885         },
886         py::arg("context") = py::none(), "Create a Unit attribute.");
887   }
888 };
889 
890 } // namespace
891 
populateIRAttributes(py::module & m)892 void mlir::python::populateIRAttributes(py::module &m) {
893   PyAffineMapAttribute::bind(m);
894   PyArrayAttribute::bind(m);
895   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
896   PyBoolAttribute::bind(m);
897   PyDenseElementsAttribute::bind(m);
898   PyDenseFPElementsAttribute::bind(m);
899   PyDenseIntElementsAttribute::bind(m);
900   PyDictAttribute::bind(m);
901   PyFlatSymbolRefAttribute::bind(m);
902   PyOpaqueAttribute::bind(m);
903   PyFloatAttribute::bind(m);
904   PyIntegerAttribute::bind(m);
905   PyStringAttribute::bind(m);
906   PyTypeAttribute::bind(m);
907   PyUnitAttribute::bind(m);
908 }
909