1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Dialects.h"
10 #include "mlir-c/Dialect/Quant.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 namespace py = pybind11;
15 using namespace llvm;
16 using namespace mlir;
17 using namespace mlir::python::adaptors;
18 
19 void mlir::python::populateDialectQuantSubmodule(const py::module &m,
20                                                  const py::module &irModule) {
21   auto typeClass = irModule.attr("Type");
22 
23   //===-------------------------------------------------------------------===//
24   // QuantizedType
25   //===-------------------------------------------------------------------===//
26 
27   auto quantizedType = mlir_type_subclass(m, "QuantizedType",
28                                           mlirTypeIsAQuantizedType, typeClass);
29   quantizedType.def_staticmethod(
30       "default_minimum_for_integer",
31       [](bool isSigned, unsigned integralWidth) {
32         return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
33                                                             integralWidth);
34       },
35       "Default minimum value for the integer with the specified signedness and "
36       "bit width.",
37       py::arg("is_signed"), py::arg("integral_width"));
38   quantizedType.def_staticmethod(
39       "default_maximum_for_integer",
40       [](bool isSigned, unsigned integralWidth) {
41         return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
42                                                             integralWidth);
43       },
44       "Default maximum value for the integer with the specified signedness and "
45       "bit width.",
46       py::arg("is_signed"), py::arg("integral_width"));
47   quantizedType.def_property_readonly(
48       "expressed_type",
49       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
50       "Type expressed by this quantized type.");
51   quantizedType.def_property_readonly(
52       "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
53       "Flags of this quantized type (named accessors should be preferred to "
54       "this)");
55   quantizedType.def_property_readonly(
56       "is_signed",
57       [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
58       "Signedness of this quantized type.");
59   quantizedType.def_property_readonly(
60       "storage_type",
61       [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
62       "Storage type backing this quantized type.");
63   quantizedType.def_property_readonly(
64       "storage_type_min",
65       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
66       "The minimum value held by the storage type of this quantized type.");
67   quantizedType.def_property_readonly(
68       "storage_type_max",
69       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
70       "The maximum value held by the storage type of this quantized type.");
71   quantizedType.def_property_readonly(
72       "storage_type_integral_width",
73       [](MlirType type) {
74         return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
75       },
76       "The bitwidth of the storage type of this quantized type.");
77   quantizedType.def(
78       "is_compatible_expressed_type",
79       [](MlirType type, MlirType candidate) {
80         return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
81       },
82       "Checks whether the candidate type can be expressed by this quantized "
83       "type.",
84       py::arg("candidate"));
85   quantizedType.def_property_readonly(
86       "quantized_element_type",
87       [](MlirType type) {
88         return mlirQuantizedTypeGetQuantizedElementType(type);
89       },
90       "Element type of this quantized type expressed as quantized type.");
91   quantizedType.def(
92       "cast_from_storage_type",
93       [](MlirType type, MlirType candidate) {
94         MlirType castResult =
95             mlirQuantizedTypeCastFromStorageType(type, candidate);
96         if (!mlirTypeIsNull(castResult))
97           return castResult;
98         throw py::type_error("Invalid cast.");
99       },
100       "Casts from a type based on the storage type of this quantized type to a "
101       "corresponding type based on the quantized type. Raises TypeError if the "
102       "cast is not valid.",
103       py::arg("candidate"));
104   quantizedType.def_staticmethod(
105       "cast_to_storage_type",
106       [](MlirType type) {
107         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
108         if (!mlirTypeIsNull(castResult))
109           return castResult;
110         throw py::type_error("Invalid cast.");
111       },
112       "Casts from a type based on a quantized type to a corresponding type "
113       "based on the storage type of this quantized type. Raises TypeError if "
114       "the cast is not valid.",
115       py::arg("type"));
116   quantizedType.def(
117       "cast_from_expressed_type",
118       [](MlirType type, MlirType candidate) {
119         MlirType castResult =
120             mlirQuantizedTypeCastFromExpressedType(type, candidate);
121         if (!mlirTypeIsNull(castResult))
122           return castResult;
123         throw py::type_error("Invalid cast.");
124       },
125       "Casts from a type based on the expressed type of this quantized type to "
126       "a corresponding type based on the quantized type. Raises TypeError if "
127       "the cast is not valid.",
128       py::arg("candidate"));
129   quantizedType.def_staticmethod(
130       "cast_to_expressed_type",
131       [](MlirType type) {
132         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
133         if (!mlirTypeIsNull(castResult))
134           return castResult;
135         throw py::type_error("Invalid cast.");
136       },
137       "Casts from a type based on a quantized type to a corresponding type "
138       "based on the expressed type of this quantized type. Raises TypeError if "
139       "the cast is not valid.",
140       py::arg("type"));
141   quantizedType.def(
142       "cast_expressed_to_storage_type",
143       [](MlirType type, MlirType candidate) {
144         MlirType castResult =
145             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
146         if (!mlirTypeIsNull(castResult))
147           return castResult;
148         throw py::type_error("Invalid cast.");
149       },
150       "Casts from a type based on the expressed type of this quantized type to "
151       "a corresponding type based on the storage type. Raises TypeError if the "
152       "cast is not valid.",
153       py::arg("candidate"));
154 
155   quantizedType.get_class().attr("FLAG_SIGNED") =
156       mlirQuantizedTypeGetSignedFlag();
157 
158   //===-------------------------------------------------------------------===//
159   // AnyQuantizedType
160   //===-------------------------------------------------------------------===//
161 
162   auto anyQuantizedType =
163       mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
164                          quantizedType.get_class());
165   anyQuantizedType.def_classmethod(
166       "get",
167       [](py::object cls, unsigned flags, MlirType storageType,
168          MlirType expressedType, int64_t storageTypeMin,
169          int64_t storageTypeMax) {
170         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
171                                            storageTypeMin, storageTypeMax));
172       },
173       "Gets an instance of AnyQuantizedType in the same context as the "
174       "provided storage type.",
175       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
176       py::arg("expressed_type"), py::arg("storage_type_min"),
177       py::arg("storage_type_max"));
178 
179   //===-------------------------------------------------------------------===//
180   // UniformQuantizedType
181   //===-------------------------------------------------------------------===//
182 
183   auto uniformQuantizedType = mlir_type_subclass(
184       m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
185       quantizedType.get_class());
186   uniformQuantizedType.def_classmethod(
187       "get",
188       [](py::object cls, unsigned flags, MlirType storageType,
189          MlirType expressedType, double scale, int64_t zeroPoint,
190          int64_t storageTypeMin, int64_t storageTypeMax) {
191         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
192                                                expressedType, scale, zeroPoint,
193                                                storageTypeMin, storageTypeMax));
194       },
195       "Gets an instance of UniformQuantizedType in the same context as the "
196       "provided storage type.",
197       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
198       py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
199       py::arg("storage_type_min"), py::arg("storage_type_max"));
200   uniformQuantizedType.def_property_readonly(
201       "scale",
202       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
203       "The scale designates the difference between the real values "
204       "corresponding to consecutive quantized values differing by 1.");
205   uniformQuantizedType.def_property_readonly(
206       "zero_point",
207       [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
208       "The storage value corresponding to the real value 0 in the affine "
209       "equation.");
210   uniformQuantizedType.def_property_readonly(
211       "is_fixed_point",
212       [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
213       "Fixed point values are real numbers divided by a scale.");
214 
215   //===-------------------------------------------------------------------===//
216   // UniformQuantizedPerAxisType
217   //===-------------------------------------------------------------------===//
218   auto uniformQuantizedPerAxisType = mlir_type_subclass(
219       m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
220       quantizedType.get_class());
221   uniformQuantizedPerAxisType.def_classmethod(
222       "get",
223       [](py::object cls, unsigned flags, MlirType storageType,
224          MlirType expressedType, std::vector<double> scales,
225          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
226          int64_t storageTypeMin, int64_t storageTypeMax) {
227         if (scales.size() != zeroPoints.size())
228           throw py::value_error(
229               "Mismatching number of scales and zero points.");
230         auto nDims = static_cast<intptr_t>(scales.size());
231         return cls(mlirUniformQuantizedPerAxisTypeGet(
232             flags, storageType, expressedType, nDims, scales.data(),
233             zeroPoints.data(), quantizedDimension, storageTypeMin,
234             storageTypeMax));
235       },
236       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
237       "the provided storage type.",
238       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
239       py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
240       py::arg("quantized_dimension"), py::arg("storage_type_min"),
241       py::arg("storage_type_max"));
242   uniformQuantizedPerAxisType.def_property_readonly(
243       "scales",
244       [](MlirType type) {
245         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
246         std::vector<double> scales;
247         scales.reserve(nDim);
248         for (intptr_t i = 0; i < nDim; ++i) {
249           double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
250           scales.push_back(scale);
251         }
252       },
253       "The scales designate the difference between the real values "
254       "corresponding to consecutive quantized values differing by 1. The ith "
255       "scale corresponds to the ith slice in the quantized_dimension.");
256   uniformQuantizedPerAxisType.def_property_readonly(
257       "zero_points",
258       [](MlirType type) {
259         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
260         std::vector<int64_t> zeroPoints;
261         zeroPoints.reserve(nDim);
262         for (intptr_t i = 0; i < nDim; ++i) {
263           int64_t zeroPoint =
264               mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
265           zeroPoints.push_back(zeroPoint);
266         }
267       },
268       "the storage values corresponding to the real value 0 in the affine "
269       "equation. The ith zero point corresponds to the ith slice in the "
270       "quantized_dimension.");
271   uniformQuantizedPerAxisType.def_property_readonly(
272       "quantized_dimension",
273       [](MlirType type) {
274         return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
275       },
276       "Specifies the dimension of the shape that the scales and zero points "
277       "correspond to.");
278   uniformQuantizedPerAxisType.def_property_readonly(
279       "is_fixed_point",
280       [](MlirType type) {
281         return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
282       },
283       "Fixed point values are real numbers divided by a scale.");
284 
285   //===-------------------------------------------------------------------===//
286   // CalibratedQuantizedType
287   //===-------------------------------------------------------------------===//
288 
289   auto calibratedQuantizedType = mlir_type_subclass(
290       m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
291       quantizedType.get_class());
292   calibratedQuantizedType.def_classmethod(
293       "get",
294       [](py::object cls, MlirType expressedType, double min, double max) {
295         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
296       },
297       "Gets an instance of CalibratedQuantizedType in the same context as the "
298       "provided expressed type.",
299       py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
300       py::arg("max"));
301   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
302     return mlirCalibratedQuantizedTypeGetMin(type);
303   });
304   calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
305     return mlirCalibratedQuantizedTypeGetMax(type);
306   });
307 }
308