1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/Bindings/Python/PybindAdaptors.h"
12 
13 namespace py = pybind11;
14 using namespace llvm;
15 using namespace mlir;
16 using namespace mlir::python::adaptors;
17 
populateDialectQuantSubmodule(const py::module & m)18 static void populateDialectQuantSubmodule(const py::module &m) {
19   //===-------------------------------------------------------------------===//
20   // QuantizedType
21   //===-------------------------------------------------------------------===//
22 
23   auto quantizedType =
24       mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
25   quantizedType.def_staticmethod(
26       "default_minimum_for_integer",
27       [](bool isSigned, unsigned integralWidth) {
28         return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
29                                                             integralWidth);
30       },
31       "Default minimum value for the integer with the specified signedness and "
32       "bit width.",
33       py::arg("is_signed"), py::arg("integral_width"));
34   quantizedType.def_staticmethod(
35       "default_maximum_for_integer",
36       [](bool isSigned, unsigned integralWidth) {
37         return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
38                                                             integralWidth);
39       },
40       "Default maximum value for the integer with the specified signedness and "
41       "bit width.",
42       py::arg("is_signed"), py::arg("integral_width"));
43   quantizedType.def_property_readonly(
44       "expressed_type",
45       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
46       "Type expressed by this quantized type.");
47   quantizedType.def_property_readonly(
48       "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
49       "Flags of this quantized type (named accessors should be preferred to "
50       "this)");
51   quantizedType.def_property_readonly(
52       "is_signed",
53       [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
54       "Signedness of this quantized type.");
55   quantizedType.def_property_readonly(
56       "storage_type",
57       [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
58       "Storage type backing this quantized type.");
59   quantizedType.def_property_readonly(
60       "storage_type_min",
61       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
62       "The minimum value held by the storage type of this quantized type.");
63   quantizedType.def_property_readonly(
64       "storage_type_max",
65       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
66       "The maximum value held by the storage type of this quantized type.");
67   quantizedType.def_property_readonly(
68       "storage_type_integral_width",
69       [](MlirType type) {
70         return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
71       },
72       "The bitwidth of the storage type of this quantized type.");
73   quantizedType.def(
74       "is_compatible_expressed_type",
75       [](MlirType type, MlirType candidate) {
76         return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
77       },
78       "Checks whether the candidate type can be expressed by this quantized "
79       "type.",
80       py::arg("candidate"));
81   quantizedType.def_property_readonly(
82       "quantized_element_type",
83       [](MlirType type) {
84         return mlirQuantizedTypeGetQuantizedElementType(type);
85       },
86       "Element type of this quantized type expressed as quantized type.");
87   quantizedType.def(
88       "cast_from_storage_type",
89       [](MlirType type, MlirType candidate) {
90         MlirType castResult =
91             mlirQuantizedTypeCastFromStorageType(type, candidate);
92         if (!mlirTypeIsNull(castResult))
93           return castResult;
94         throw py::type_error("Invalid cast.");
95       },
96       "Casts from a type based on the storage type of this quantized type to a "
97       "corresponding type based on the quantized type. Raises TypeError if the "
98       "cast is not valid.",
99       py::arg("candidate"));
100   quantizedType.def_staticmethod(
101       "cast_to_storage_type",
102       [](MlirType type) {
103         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
104         if (!mlirTypeIsNull(castResult))
105           return castResult;
106         throw py::type_error("Invalid cast.");
107       },
108       "Casts from a type based on a quantized type to a corresponding type "
109       "based on the storage type of this quantized type. Raises TypeError if "
110       "the cast is not valid.",
111       py::arg("type"));
112   quantizedType.def(
113       "cast_from_expressed_type",
114       [](MlirType type, MlirType candidate) {
115         MlirType castResult =
116             mlirQuantizedTypeCastFromExpressedType(type, candidate);
117         if (!mlirTypeIsNull(castResult))
118           return castResult;
119         throw py::type_error("Invalid cast.");
120       },
121       "Casts from a type based on the expressed type of this quantized type to "
122       "a corresponding type based on the quantized type. Raises TypeError if "
123       "the cast is not valid.",
124       py::arg("candidate"));
125   quantizedType.def_staticmethod(
126       "cast_to_expressed_type",
127       [](MlirType type) {
128         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
129         if (!mlirTypeIsNull(castResult))
130           return castResult;
131         throw py::type_error("Invalid cast.");
132       },
133       "Casts from a type based on a quantized type to a corresponding type "
134       "based on the expressed type of this quantized type. Raises TypeError if "
135       "the cast is not valid.",
136       py::arg("type"));
137   quantizedType.def(
138       "cast_expressed_to_storage_type",
139       [](MlirType type, MlirType candidate) {
140         MlirType castResult =
141             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
142         if (!mlirTypeIsNull(castResult))
143           return castResult;
144         throw py::type_error("Invalid cast.");
145       },
146       "Casts from a type based on the expressed type of this quantized type to "
147       "a corresponding type based on the storage type. Raises TypeError if the "
148       "cast is not valid.",
149       py::arg("candidate"));
150 
151   quantizedType.get_class().attr("FLAG_SIGNED") =
152       mlirQuantizedTypeGetSignedFlag();
153 
154   //===-------------------------------------------------------------------===//
155   // AnyQuantizedType
156   //===-------------------------------------------------------------------===//
157 
158   auto anyQuantizedType =
159       mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
160                          quantizedType.get_class());
161   anyQuantizedType.def_classmethod(
162       "get",
163       [](py::object cls, unsigned flags, MlirType storageType,
164          MlirType expressedType, int64_t storageTypeMin,
165          int64_t storageTypeMax) {
166         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
167                                            storageTypeMin, storageTypeMax));
168       },
169       "Gets an instance of AnyQuantizedType in the same context as the "
170       "provided storage type.",
171       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
172       py::arg("expressed_type"), py::arg("storage_type_min"),
173       py::arg("storage_type_max"));
174 
175   //===-------------------------------------------------------------------===//
176   // UniformQuantizedType
177   //===-------------------------------------------------------------------===//
178 
179   auto uniformQuantizedType = mlir_type_subclass(
180       m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
181       quantizedType.get_class());
182   uniformQuantizedType.def_classmethod(
183       "get",
184       [](py::object cls, unsigned flags, MlirType storageType,
185          MlirType expressedType, double scale, int64_t zeroPoint,
186          int64_t storageTypeMin, int64_t storageTypeMax) {
187         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
188                                                expressedType, scale, zeroPoint,
189                                                storageTypeMin, storageTypeMax));
190       },
191       "Gets an instance of UniformQuantizedType in the same context as the "
192       "provided storage type.",
193       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
194       py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
195       py::arg("storage_type_min"), py::arg("storage_type_max"));
196   uniformQuantizedType.def_property_readonly(
197       "scale",
198       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
199       "The scale designates the difference between the real values "
200       "corresponding to consecutive quantized values differing by 1.");
201   uniformQuantizedType.def_property_readonly(
202       "zero_point",
203       [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
204       "The storage value corresponding to the real value 0 in the affine "
205       "equation.");
206   uniformQuantizedType.def_property_readonly(
207       "is_fixed_point",
208       [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
209       "Fixed point values are real numbers divided by a scale.");
210 
211   //===-------------------------------------------------------------------===//
212   // UniformQuantizedPerAxisType
213   //===-------------------------------------------------------------------===//
214   auto uniformQuantizedPerAxisType = mlir_type_subclass(
215       m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
216       quantizedType.get_class());
217   uniformQuantizedPerAxisType.def_classmethod(
218       "get",
219       [](py::object cls, unsigned flags, MlirType storageType,
220          MlirType expressedType, std::vector<double> scales,
221          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
222          int64_t storageTypeMin, int64_t storageTypeMax) {
223         if (scales.size() != zeroPoints.size())
224           throw py::value_error(
225               "Mismatching number of scales and zero points.");
226         auto nDims = static_cast<intptr_t>(scales.size());
227         return cls(mlirUniformQuantizedPerAxisTypeGet(
228             flags, storageType, expressedType, nDims, scales.data(),
229             zeroPoints.data(), quantizedDimension, storageTypeMin,
230             storageTypeMax));
231       },
232       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
233       "the provided storage type.",
234       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
235       py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
236       py::arg("quantized_dimension"), py::arg("storage_type_min"),
237       py::arg("storage_type_max"));
238   uniformQuantizedPerAxisType.def_property_readonly(
239       "scales",
240       [](MlirType type) {
241         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
242         std::vector<double> scales;
243         scales.reserve(nDim);
244         for (intptr_t i = 0; i < nDim; ++i) {
245           double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
246           scales.push_back(scale);
247         }
248       },
249       "The scales designate the difference between the real values "
250       "corresponding to consecutive quantized values differing by 1. The ith "
251       "scale corresponds to the ith slice in the quantized_dimension.");
252   uniformQuantizedPerAxisType.def_property_readonly(
253       "zero_points",
254       [](MlirType type) {
255         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
256         std::vector<int64_t> zeroPoints;
257         zeroPoints.reserve(nDim);
258         for (intptr_t i = 0; i < nDim; ++i) {
259           int64_t zeroPoint =
260               mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
261           zeroPoints.push_back(zeroPoint);
262         }
263       },
264       "the storage values corresponding to the real value 0 in the affine "
265       "equation. The ith zero point corresponds to the ith slice in the "
266       "quantized_dimension.");
267   uniformQuantizedPerAxisType.def_property_readonly(
268       "quantized_dimension",
269       [](MlirType type) {
270         return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
271       },
272       "Specifies the dimension of the shape that the scales and zero points "
273       "correspond to.");
274   uniformQuantizedPerAxisType.def_property_readonly(
275       "is_fixed_point",
276       [](MlirType type) {
277         return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
278       },
279       "Fixed point values are real numbers divided by a scale.");
280 
281   //===-------------------------------------------------------------------===//
282   // CalibratedQuantizedType
283   //===-------------------------------------------------------------------===//
284 
285   auto calibratedQuantizedType = mlir_type_subclass(
286       m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
287       quantizedType.get_class());
288   calibratedQuantizedType.def_classmethod(
289       "get",
290       [](py::object cls, MlirType expressedType, double min, double max) {
291         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
292       },
293       "Gets an instance of CalibratedQuantizedType in the same context as the "
294       "provided expressed type.",
295       py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
296       py::arg("max"));
297   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
298     return mlirCalibratedQuantizedTypeGetMin(type);
299   });
300   calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
301     return mlirCalibratedQuantizedTypeGetMax(type);
302   });
303 }
304 
// Entry point of the `_mlirDialectsQuant` native extension: attaches the
// module docstring, then registers every quant dialect type class on it.
PYBIND11_MODULE(_mlirDialectsQuant, m) {
  m.attr("__doc__") = "MLIR Quantization dialect";
  populateDialectQuantSubmodule(m);
}
310