1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/Bindings/Python/PybindAdaptors.h"
12
13 namespace py = pybind11;
14 using namespace llvm;
15 using namespace mlir;
16 using namespace mlir::python::adaptors;
17
populateDialectQuantSubmodule(const py::module & m)18 static void populateDialectQuantSubmodule(const py::module &m) {
19 //===-------------------------------------------------------------------===//
20 // QuantizedType
21 //===-------------------------------------------------------------------===//
22
23 auto quantizedType =
24 mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
25 quantizedType.def_staticmethod(
26 "default_minimum_for_integer",
27 [](bool isSigned, unsigned integralWidth) {
28 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
29 integralWidth);
30 },
31 "Default minimum value for the integer with the specified signedness and "
32 "bit width.",
33 py::arg("is_signed"), py::arg("integral_width"));
34 quantizedType.def_staticmethod(
35 "default_maximum_for_integer",
36 [](bool isSigned, unsigned integralWidth) {
37 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
38 integralWidth);
39 },
40 "Default maximum value for the integer with the specified signedness and "
41 "bit width.",
42 py::arg("is_signed"), py::arg("integral_width"));
43 quantizedType.def_property_readonly(
44 "expressed_type",
45 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
46 "Type expressed by this quantized type.");
47 quantizedType.def_property_readonly(
48 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
49 "Flags of this quantized type (named accessors should be preferred to "
50 "this)");
51 quantizedType.def_property_readonly(
52 "is_signed",
53 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
54 "Signedness of this quantized type.");
55 quantizedType.def_property_readonly(
56 "storage_type",
57 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
58 "Storage type backing this quantized type.");
59 quantizedType.def_property_readonly(
60 "storage_type_min",
61 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
62 "The minimum value held by the storage type of this quantized type.");
63 quantizedType.def_property_readonly(
64 "storage_type_max",
65 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
66 "The maximum value held by the storage type of this quantized type.");
67 quantizedType.def_property_readonly(
68 "storage_type_integral_width",
69 [](MlirType type) {
70 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
71 },
72 "The bitwidth of the storage type of this quantized type.");
73 quantizedType.def(
74 "is_compatible_expressed_type",
75 [](MlirType type, MlirType candidate) {
76 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
77 },
78 "Checks whether the candidate type can be expressed by this quantized "
79 "type.",
80 py::arg("candidate"));
81 quantizedType.def_property_readonly(
82 "quantized_element_type",
83 [](MlirType type) {
84 return mlirQuantizedTypeGetQuantizedElementType(type);
85 },
86 "Element type of this quantized type expressed as quantized type.");
87 quantizedType.def(
88 "cast_from_storage_type",
89 [](MlirType type, MlirType candidate) {
90 MlirType castResult =
91 mlirQuantizedTypeCastFromStorageType(type, candidate);
92 if (!mlirTypeIsNull(castResult))
93 return castResult;
94 throw py::type_error("Invalid cast.");
95 },
96 "Casts from a type based on the storage type of this quantized type to a "
97 "corresponding type based on the quantized type. Raises TypeError if the "
98 "cast is not valid.",
99 py::arg("candidate"));
100 quantizedType.def_staticmethod(
101 "cast_to_storage_type",
102 [](MlirType type) {
103 MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
104 if (!mlirTypeIsNull(castResult))
105 return castResult;
106 throw py::type_error("Invalid cast.");
107 },
108 "Casts from a type based on a quantized type to a corresponding type "
109 "based on the storage type of this quantized type. Raises TypeError if "
110 "the cast is not valid.",
111 py::arg("type"));
112 quantizedType.def(
113 "cast_from_expressed_type",
114 [](MlirType type, MlirType candidate) {
115 MlirType castResult =
116 mlirQuantizedTypeCastFromExpressedType(type, candidate);
117 if (!mlirTypeIsNull(castResult))
118 return castResult;
119 throw py::type_error("Invalid cast.");
120 },
121 "Casts from a type based on the expressed type of this quantized type to "
122 "a corresponding type based on the quantized type. Raises TypeError if "
123 "the cast is not valid.",
124 py::arg("candidate"));
125 quantizedType.def_staticmethod(
126 "cast_to_expressed_type",
127 [](MlirType type) {
128 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
129 if (!mlirTypeIsNull(castResult))
130 return castResult;
131 throw py::type_error("Invalid cast.");
132 },
133 "Casts from a type based on a quantized type to a corresponding type "
134 "based on the expressed type of this quantized type. Raises TypeError if "
135 "the cast is not valid.",
136 py::arg("type"));
137 quantizedType.def(
138 "cast_expressed_to_storage_type",
139 [](MlirType type, MlirType candidate) {
140 MlirType castResult =
141 mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
142 if (!mlirTypeIsNull(castResult))
143 return castResult;
144 throw py::type_error("Invalid cast.");
145 },
146 "Casts from a type based on the expressed type of this quantized type to "
147 "a corresponding type based on the storage type. Raises TypeError if the "
148 "cast is not valid.",
149 py::arg("candidate"));
150
151 quantizedType.get_class().attr("FLAG_SIGNED") =
152 mlirQuantizedTypeGetSignedFlag();
153
154 //===-------------------------------------------------------------------===//
155 // AnyQuantizedType
156 //===-------------------------------------------------------------------===//
157
158 auto anyQuantizedType =
159 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
160 quantizedType.get_class());
161 anyQuantizedType.def_classmethod(
162 "get",
163 [](py::object cls, unsigned flags, MlirType storageType,
164 MlirType expressedType, int64_t storageTypeMin,
165 int64_t storageTypeMax) {
166 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
167 storageTypeMin, storageTypeMax));
168 },
169 "Gets an instance of AnyQuantizedType in the same context as the "
170 "provided storage type.",
171 py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
172 py::arg("expressed_type"), py::arg("storage_type_min"),
173 py::arg("storage_type_max"));
174
175 //===-------------------------------------------------------------------===//
176 // UniformQuantizedType
177 //===-------------------------------------------------------------------===//
178
179 auto uniformQuantizedType = mlir_type_subclass(
180 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
181 quantizedType.get_class());
182 uniformQuantizedType.def_classmethod(
183 "get",
184 [](py::object cls, unsigned flags, MlirType storageType,
185 MlirType expressedType, double scale, int64_t zeroPoint,
186 int64_t storageTypeMin, int64_t storageTypeMax) {
187 return cls(mlirUniformQuantizedTypeGet(flags, storageType,
188 expressedType, scale, zeroPoint,
189 storageTypeMin, storageTypeMax));
190 },
191 "Gets an instance of UniformQuantizedType in the same context as the "
192 "provided storage type.",
193 py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
194 py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
195 py::arg("storage_type_min"), py::arg("storage_type_max"));
196 uniformQuantizedType.def_property_readonly(
197 "scale",
198 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
199 "The scale designates the difference between the real values "
200 "corresponding to consecutive quantized values differing by 1.");
201 uniformQuantizedType.def_property_readonly(
202 "zero_point",
203 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
204 "The storage value corresponding to the real value 0 in the affine "
205 "equation.");
206 uniformQuantizedType.def_property_readonly(
207 "is_fixed_point",
208 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
209 "Fixed point values are real numbers divided by a scale.");
210
211 //===-------------------------------------------------------------------===//
212 // UniformQuantizedPerAxisType
213 //===-------------------------------------------------------------------===//
214 auto uniformQuantizedPerAxisType = mlir_type_subclass(
215 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
216 quantizedType.get_class());
217 uniformQuantizedPerAxisType.def_classmethod(
218 "get",
219 [](py::object cls, unsigned flags, MlirType storageType,
220 MlirType expressedType, std::vector<double> scales,
221 std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
222 int64_t storageTypeMin, int64_t storageTypeMax) {
223 if (scales.size() != zeroPoints.size())
224 throw py::value_error(
225 "Mismatching number of scales and zero points.");
226 auto nDims = static_cast<intptr_t>(scales.size());
227 return cls(mlirUniformQuantizedPerAxisTypeGet(
228 flags, storageType, expressedType, nDims, scales.data(),
229 zeroPoints.data(), quantizedDimension, storageTypeMin,
230 storageTypeMax));
231 },
232 "Gets an instance of UniformQuantizedPerAxisType in the same context as "
233 "the provided storage type.",
234 py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
235 py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
236 py::arg("quantized_dimension"), py::arg("storage_type_min"),
237 py::arg("storage_type_max"));
238 uniformQuantizedPerAxisType.def_property_readonly(
239 "scales",
240 [](MlirType type) {
241 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
242 std::vector<double> scales;
243 scales.reserve(nDim);
244 for (intptr_t i = 0; i < nDim; ++i) {
245 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
246 scales.push_back(scale);
247 }
248 },
249 "The scales designate the difference between the real values "
250 "corresponding to consecutive quantized values differing by 1. The ith "
251 "scale corresponds to the ith slice in the quantized_dimension.");
252 uniformQuantizedPerAxisType.def_property_readonly(
253 "zero_points",
254 [](MlirType type) {
255 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
256 std::vector<int64_t> zeroPoints;
257 zeroPoints.reserve(nDim);
258 for (intptr_t i = 0; i < nDim; ++i) {
259 int64_t zeroPoint =
260 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
261 zeroPoints.push_back(zeroPoint);
262 }
263 },
264 "the storage values corresponding to the real value 0 in the affine "
265 "equation. The ith zero point corresponds to the ith slice in the "
266 "quantized_dimension.");
267 uniformQuantizedPerAxisType.def_property_readonly(
268 "quantized_dimension",
269 [](MlirType type) {
270 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
271 },
272 "Specifies the dimension of the shape that the scales and zero points "
273 "correspond to.");
274 uniformQuantizedPerAxisType.def_property_readonly(
275 "is_fixed_point",
276 [](MlirType type) {
277 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
278 },
279 "Fixed point values are real numbers divided by a scale.");
280
281 //===-------------------------------------------------------------------===//
282 // CalibratedQuantizedType
283 //===-------------------------------------------------------------------===//
284
285 auto calibratedQuantizedType = mlir_type_subclass(
286 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
287 quantizedType.get_class());
288 calibratedQuantizedType.def_classmethod(
289 "get",
290 [](py::object cls, MlirType expressedType, double min, double max) {
291 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
292 },
293 "Gets an instance of CalibratedQuantizedType in the same context as the "
294 "provided expressed type.",
295 py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
296 py::arg("max"));
297 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
298 return mlirCalibratedQuantizedTypeGetMin(type);
299 });
300 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
301 return mlirCalibratedQuantizedTypeGetMax(type);
302 });
303 }
304
/// Defines the `_mlirDialectsQuant` Python extension module: sets its
/// docstring and then registers the 'quant' dialect bindings on it.
PYBIND11_MODULE(_mlirDialectsQuant, m) {
  // Module-level docstring.
  m.doc() = "MLIR Quantization dialect";
  // Attach all quantized type classes to the module.
  populateDialectQuantSubmodule(m);
}
310