1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Quant.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/Bindings/Python/PybindAdaptors.h" 12 13 namespace py = pybind11; 14 using namespace llvm; 15 using namespace mlir; 16 using namespace mlir::python::adaptors; 17 18 static void populateDialectQuantSubmodule(const py::module &m) { 19 //===-------------------------------------------------------------------===// 20 // QuantizedType 21 //===-------------------------------------------------------------------===// 22 23 auto quantizedType = 24 mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); 25 quantizedType.def_staticmethod( 26 "default_minimum_for_integer", 27 [](bool isSigned, unsigned integralWidth) { 28 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, 29 integralWidth); 30 }, 31 "Default minimum value for the integer with the specified signedness and " 32 "bit width.", 33 py::arg("is_signed"), py::arg("integral_width")); 34 quantizedType.def_staticmethod( 35 "default_maximum_for_integer", 36 [](bool isSigned, unsigned integralWidth) { 37 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, 38 integralWidth); 39 }, 40 "Default maximum value for the integer with the specified signedness and " 41 "bit width.", 42 py::arg("is_signed"), py::arg("integral_width")); 43 quantizedType.def_property_readonly( 44 "expressed_type", 45 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, 46 "Type expressed by this quantized type."); 47 quantizedType.def_property_readonly( 48 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, 49 "Flags of this quantized type (named accessors should be preferred to " 50 "this)"); 51 quantizedType.def_property_readonly( 52 "is_signed", 53 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, 54 "Signedness of this quantized type."); 55 quantizedType.def_property_readonly( 56 "storage_type", 57 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, 58 "Storage type backing this quantized type."); 59 quantizedType.def_property_readonly( 60 "storage_type_min", 61 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, 62 "The minimum value held by the storage type of this quantized type."); 63 quantizedType.def_property_readonly( 64 "storage_type_max", 65 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, 66 "The maximum value held by the storage type of this quantized type."); 67 quantizedType.def_property_readonly( 68 "storage_type_integral_width", 69 [](MlirType type) { 70 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); 71 }, 72 "The bitwidth of the storage type of this quantized type."); 73 quantizedType.def( 74 "is_compatible_expressed_type", 75 [](MlirType type, MlirType candidate) { 76 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); 77 }, 78 "Checks whether the candidate type can be expressed by this quantized " 79 "type.", 80 py::arg("candidate")); 81 quantizedType.def_property_readonly( 82 "quantized_element_type", 83 [](MlirType type) { 84 return mlirQuantizedTypeGetQuantizedElementType(type); 85 }, 86 "Element type of this quantized type expressed as quantized type."); 87 quantizedType.def( 88 "cast_from_storage_type", 89 [](MlirType type, MlirType candidate) { 90 MlirType castResult = 91 mlirQuantizedTypeCastFromStorageType(type, candidate); 92 if (!mlirTypeIsNull(castResult)) 93 return castResult; 94 throw py::type_error("Invalid cast."); 95 }, 96 "Casts from a type based on the storage type of this quantized type to a " 97 "corresponding type based on the quantized type. Raises TypeError if the " 98 "cast is not valid.", 99 py::arg("candidate")); 100 quantizedType.def_staticmethod( 101 "cast_to_storage_type", 102 [](MlirType type) { 103 MlirType castResult = mlirQuantizedTypeCastToStorageType(type); 104 if (!mlirTypeIsNull(castResult)) 105 return castResult; 106 throw py::type_error("Invalid cast."); 107 }, 108 "Casts from a type based on a quantized type to a corresponding type " 109 "based on the storage type of this quantized type. Raises TypeError if " 110 "the cast is not valid.", 111 py::arg("type")); 112 quantizedType.def( 113 "cast_from_expressed_type", 114 [](MlirType type, MlirType candidate) { 115 MlirType castResult = 116 mlirQuantizedTypeCastFromExpressedType(type, candidate); 117 if (!mlirTypeIsNull(castResult)) 118 return castResult; 119 throw py::type_error("Invalid cast."); 120 }, 121 "Casts from a type based on the expressed type of this quantized type to " 122 "a corresponding type based on the quantized type. Raises TypeError if " 123 "the cast is not valid.", 124 py::arg("candidate")); 125 quantizedType.def_staticmethod( 126 "cast_to_expressed_type", 127 [](MlirType type) { 128 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); 129 if (!mlirTypeIsNull(castResult)) 130 return castResult; 131 throw py::type_error("Invalid cast."); 132 }, 133 "Casts from a type based on a quantized type to a corresponding type " 134 "based on the expressed type of this quantized type. Raises TypeError if " 135 "the cast is not valid.", 136 py::arg("type")); 137 quantizedType.def( 138 "cast_expressed_to_storage_type", 139 [](MlirType type, MlirType candidate) { 140 MlirType castResult = 141 mlirQuantizedTypeCastExpressedToStorageType(type, candidate); 142 if (!mlirTypeIsNull(castResult)) 143 return castResult; 144 throw py::type_error("Invalid cast."); 145 }, 146 "Casts from a type based on the expressed type of this quantized type to " 147 "a corresponding type based on the storage type. Raises TypeError if the " 148 "cast is not valid.", 149 py::arg("candidate")); 150 151 quantizedType.get_class().attr("FLAG_SIGNED") = 152 mlirQuantizedTypeGetSignedFlag(); 153 154 //===-------------------------------------------------------------------===// 155 // AnyQuantizedType 156 //===-------------------------------------------------------------------===// 157 158 auto anyQuantizedType = 159 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, 160 quantizedType.get_class()); 161 anyQuantizedType.def_classmethod( 162 "get", 163 [](py::object cls, unsigned flags, MlirType storageType, 164 MlirType expressedType, int64_t storageTypeMin, 165 int64_t storageTypeMax) { 166 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, 167 storageTypeMin, storageTypeMax)); 168 }, 169 "Gets an instance of AnyQuantizedType in the same context as the " 170 "provided storage type.", 171 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 172 py::arg("expressed_type"), py::arg("storage_type_min"), 173 py::arg("storage_type_max")); 174 175 //===-------------------------------------------------------------------===// 176 // UniformQuantizedType 177 //===-------------------------------------------------------------------===// 178 179 auto uniformQuantizedType = mlir_type_subclass( 180 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, 181 quantizedType.get_class()); 182 uniformQuantizedType.def_classmethod( 183 "get", 184 [](py::object cls, unsigned flags, MlirType storageType, 185 MlirType expressedType, double scale, int64_t zeroPoint, 186 int64_t storageTypeMin, int64_t storageTypeMax) { 187 return cls(mlirUniformQuantizedTypeGet(flags, storageType, 188 expressedType, scale, zeroPoint, 189 storageTypeMin, storageTypeMax)); 190 }, 191 "Gets an instance of UniformQuantizedType in the same context as the " 192 "provided storage type.", 193 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 194 py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), 195 py::arg("storage_type_min"), py::arg("storage_type_max")); 196 uniformQuantizedType.def_property_readonly( 197 "scale", 198 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, 199 "The scale designates the difference between the real values " 200 "corresponding to consecutive quantized values differing by 1."); 201 uniformQuantizedType.def_property_readonly( 202 "zero_point", 203 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, 204 "The storage value corresponding to the real value 0 in the affine " 205 "equation."); 206 uniformQuantizedType.def_property_readonly( 207 "is_fixed_point", 208 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, 209 "Fixed point values are real numbers divided by a scale."); 210 211 //===-------------------------------------------------------------------===// 212 // UniformQuantizedPerAxisType 213 //===-------------------------------------------------------------------===// 214 auto uniformQuantizedPerAxisType = mlir_type_subclass( 215 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, 216 quantizedType.get_class()); 217 uniformQuantizedPerAxisType.def_classmethod( 218 "get", 219 [](py::object cls, unsigned flags, MlirType storageType, 220 MlirType expressedType, std::vector<double> scales, 221 std::vector<int64_t> zeroPoints, int32_t quantizedDimension, 222 int64_t storageTypeMin, int64_t storageTypeMax) { 223 if (scales.size() != zeroPoints.size()) 224 throw py::value_error( 225 "Mismatching number of scales and zero points."); 226 auto nDims = static_cast<intptr_t>(scales.size()); 227 return cls(mlirUniformQuantizedPerAxisTypeGet( 228 flags, storageType, expressedType, nDims, scales.data(), 229 zeroPoints.data(), quantizedDimension, storageTypeMin, 230 storageTypeMax)); 231 }, 232 "Gets an instance of UniformQuantizedPerAxisType in the same context as " 233 "the provided storage type.", 234 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 235 py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), 236 py::arg("quantized_dimension"), py::arg("storage_type_min"), 237 py::arg("storage_type_max")); 238 uniformQuantizedPerAxisType.def_property_readonly( 239 "scales", 240 [](MlirType type) { 241 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 242 std::vector<double> scales; 243 scales.reserve(nDim); 244 for (intptr_t i = 0; i < nDim; ++i) { 245 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); 246 scales.push_back(scale); 247 } 248 }, 249 "The scales designate the difference between the real values " 250 "corresponding to consecutive quantized values differing by 1. The ith " 251 "scale corresponds to the ith slice in the quantized_dimension."); 252 uniformQuantizedPerAxisType.def_property_readonly( 253 "zero_points", 254 [](MlirType type) { 255 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 256 std::vector<int64_t> zeroPoints; 257 zeroPoints.reserve(nDim); 258 for (intptr_t i = 0; i < nDim; ++i) { 259 int64_t zeroPoint = 260 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); 261 zeroPoints.push_back(zeroPoint); 262 } 263 }, 264 "the storage values corresponding to the real value 0 in the affine " 265 "equation. The ith zero point corresponds to the ith slice in the " 266 "quantized_dimension."); 267 uniformQuantizedPerAxisType.def_property_readonly( 268 "quantized_dimension", 269 [](MlirType type) { 270 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); 271 }, 272 "Specifies the dimension of the shape that the scales and zero points " 273 "correspond to."); 274 uniformQuantizedPerAxisType.def_property_readonly( 275 "is_fixed_point", 276 [](MlirType type) { 277 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); 278 }, 279 "Fixed point values are real numbers divided by a scale."); 280 281 //===-------------------------------------------------------------------===// 282 // CalibratedQuantizedType 283 //===-------------------------------------------------------------------===// 284 285 auto calibratedQuantizedType = mlir_type_subclass( 286 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, 287 quantizedType.get_class()); 288 calibratedQuantizedType.def_classmethod( 289 "get", 290 [](py::object cls, MlirType expressedType, double min, double max) { 291 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); 292 }, 293 "Gets an instance of CalibratedQuantizedType in the same context as the " 294 "provided expressed type.", 295 py::arg("cls"), py::arg("expressed_type"), py::arg("min"), 296 py::arg("max")); 297 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { 298 return mlirCalibratedQuantizedTypeGetMin(type); 299 }); 300 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { 301 return mlirCalibratedQuantizedTypeGetMax(type); 302 }); 303 } 304 305 PYBIND11_MODULE(_mlirDialectsQuant, m) { 306 m.doc() = "MLIR Quantization dialect"; 307 308 populateDialectQuantSubmodule(m); 309 } 310