1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Dialects.h" 10 #include "mlir-c/Dialect/Quant.h" 11 #include "mlir-c/IR.h" 12 #include "mlir/Bindings/Python/PybindAdaptors.h" 13 14 namespace py = pybind11; 15 using namespace llvm; 16 using namespace mlir; 17 using namespace mlir::python::adaptors; 18 19 void mlir::python::populateDialectQuantSubmodule(const py::module &m, 20 const py::module &irModule) { 21 auto typeClass = irModule.attr("Type"); 22 23 //===-------------------------------------------------------------------===// 24 // QuantizedType 25 //===-------------------------------------------------------------------===// 26 27 auto quantizedType = mlir_type_subclass(m, "QuantizedType", 28 mlirTypeIsAQuantizedType, typeClass); 29 quantizedType.def_staticmethod( 30 "default_minimum_for_integer", 31 [](bool isSigned, unsigned integralWidth) { 32 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, 33 integralWidth); 34 }, 35 "Default minimum value for the integer with the specified signedness and " 36 "bit width.", 37 py::arg("is_signed"), py::arg("integral_width")); 38 quantizedType.def_staticmethod( 39 "default_maximum_for_integer", 40 [](bool isSigned, unsigned integralWidth) { 41 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, 42 integralWidth); 43 }, 44 "Default maximum value for the integer with the specified signedness and " 45 "bit width.", 46 py::arg("is_signed"), py::arg("integral_width")); 47 quantizedType.def_property_readonly( 48 "expressed_type", 49 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, 50 "Type expressed by this quantized type."); 51 quantizedType.def_property_readonly( 52 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, 53 "Flags of this quantized type (named accessors should be preferred to " 54 "this)"); 55 quantizedType.def_property_readonly( 56 "is_signed", 57 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, 58 "Signedness of this quantized type."); 59 quantizedType.def_property_readonly( 60 "storage_type", 61 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, 62 "Storage type backing this quantized type."); 63 quantizedType.def_property_readonly( 64 "storage_type_min", 65 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, 66 "The minimum value held by the storage type of this quantized type."); 67 quantizedType.def_property_readonly( 68 "storage_type_max", 69 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, 70 "The maximum value held by the storage type of this quantized type."); 71 quantizedType.def_property_readonly( 72 "storage_type_integral_width", 73 [](MlirType type) { 74 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); 75 }, 76 "The bitwidth of the storage type of this quantized type."); 77 quantizedType.def( 78 "is_compatible_expressed_type", 79 [](MlirType type, MlirType candidate) { 80 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); 81 }, 82 "Checks whether the candidate type can be expressed by this quantized " 83 "type.", 84 py::arg("candidate")); 85 quantizedType.def_property_readonly( 86 "quantized_element_type", 87 [](MlirType type) { 88 return mlirQuantizedTypeGetQuantizedElementType(type); 89 }, 90 "Element type of this quantized type expressed as quantized type."); 91 quantizedType.def( 92 "cast_from_storage_type", 93 [](MlirType type, MlirType candidate) { 94 MlirType castResult = 95 mlirQuantizedTypeCastFromStorageType(type, candidate); 96 if (!mlirTypeIsNull(castResult)) 97 return castResult; 98 throw py::type_error("Invalid cast."); 99 }, 100 "Casts from a type based on the storage type of this quantized type to a " 101 "corresponding type based on the quantized type. Raises TypeError if the " 102 "cast is not valid.", 103 py::arg("candidate")); 104 quantizedType.def_staticmethod( 105 "cast_to_storage_type", 106 [](MlirType type) { 107 MlirType castResult = mlirQuantizedTypeCastToStorageType(type); 108 if (!mlirTypeIsNull(castResult)) 109 return castResult; 110 throw py::type_error("Invalid cast."); 111 }, 112 "Casts from a type based on a quantized type to a corresponding type " 113 "based on the storage type of this quantized type. Raises TypeError if " 114 "the cast is not valid.", 115 py::arg("type")); 116 quantizedType.def( 117 "cast_from_expressed_type", 118 [](MlirType type, MlirType candidate) { 119 MlirType castResult = 120 mlirQuantizedTypeCastFromExpressedType(type, candidate); 121 if (!mlirTypeIsNull(castResult)) 122 return castResult; 123 throw py::type_error("Invalid cast."); 124 }, 125 "Casts from a type based on the expressed type of this quantized type to " 126 "a corresponding type based on the quantized type. Raises TypeError if " 127 "the cast is not valid.", 128 py::arg("candidate")); 129 quantizedType.def_staticmethod( 130 "cast_to_expressed_type", 131 [](MlirType type) { 132 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); 133 if (!mlirTypeIsNull(castResult)) 134 return castResult; 135 throw py::type_error("Invalid cast."); 136 }, 137 "Casts from a type based on a quantized type to a corresponding type " 138 "based on the expressed type of this quantized type. Raises TypeError if " 139 "the cast is not valid.", 140 py::arg("type")); 141 quantizedType.def( 142 "cast_expressed_to_storage_type", 143 [](MlirType type, MlirType candidate) { 144 MlirType castResult = 145 mlirQuantizedTypeCastExpressedToStorageType(type, candidate); 146 if (!mlirTypeIsNull(castResult)) 147 return castResult; 148 throw py::type_error("Invalid cast."); 149 }, 150 "Casts from a type based on the expressed type of this quantized type to " 151 "a corresponding type based on the storage type. Raises TypeError if the " 152 "cast is not valid.", 153 py::arg("candidate")); 154 155 quantizedType.get_class().attr("FLAG_SIGNED") = 156 mlirQuantizedTypeGetSignedFlag(); 157 158 //===-------------------------------------------------------------------===// 159 // AnyQuantizedType 160 //===-------------------------------------------------------------------===// 161 162 auto anyQuantizedType = 163 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, 164 quantizedType.get_class()); 165 anyQuantizedType.def_classmethod( 166 "get", 167 [](py::object cls, unsigned flags, MlirType storageType, 168 MlirType expressedType, int64_t storageTypeMin, 169 int64_t storageTypeMax) { 170 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, 171 storageTypeMin, storageTypeMax)); 172 }, 173 "Gets an instance of AnyQuantizedType in the same context as the " 174 "provided storage type.", 175 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 176 py::arg("expressed_type"), py::arg("storage_type_min"), 177 py::arg("storage_type_max")); 178 179 //===-------------------------------------------------------------------===// 180 // UniformQuantizedType 181 //===-------------------------------------------------------------------===// 182 183 auto uniformQuantizedType = mlir_type_subclass( 184 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, 185 quantizedType.get_class()); 186 uniformQuantizedType.def_classmethod( 187 "get", 188 [](py::object cls, unsigned flags, MlirType storageType, 189 MlirType expressedType, double scale, int64_t zeroPoint, 190 int64_t storageTypeMin, int64_t storageTypeMax) { 191 return cls(mlirUniformQuantizedTypeGet(flags, storageType, 192 expressedType, scale, zeroPoint, 193 storageTypeMin, storageTypeMax)); 194 }, 195 "Gets an instance of UniformQuantizedType in the same context as the " 196 "provided storage type.", 197 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 198 py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), 199 py::arg("storage_type_min"), py::arg("storage_type_max")); 200 uniformQuantizedType.def_property_readonly( 201 "scale", 202 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, 203 "The scale designates the difference between the real values " 204 "corresponding to consecutive quantized values differing by 1."); 205 uniformQuantizedType.def_property_readonly( 206 "zero_point", 207 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, 208 "The storage value corresponding to the real value 0 in the affine " 209 "equation."); 210 uniformQuantizedType.def_property_readonly( 211 "is_fixed_point", 212 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, 213 "Fixed point values are real numbers divided by a scale."); 214 215 //===-------------------------------------------------------------------===// 216 // UniformQuantizedPerAxisType 217 //===-------------------------------------------------------------------===// 218 auto uniformQuantizedPerAxisType = mlir_type_subclass( 219 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, 220 quantizedType.get_class()); 221 uniformQuantizedPerAxisType.def_classmethod( 222 "get", 223 [](py::object cls, unsigned flags, MlirType storageType, 224 MlirType expressedType, std::vector<double> scales, 225 std::vector<int64_t> zeroPoints, int32_t quantizedDimension, 226 int64_t storageTypeMin, int64_t storageTypeMax) { 227 if (scales.size() != zeroPoints.size()) 228 throw py::value_error( 229 "Mismatching number of scales and zero points."); 230 auto nDims = static_cast<intptr_t>(scales.size()); 231 return cls(mlirUniformQuantizedPerAxisTypeGet( 232 flags, storageType, expressedType, nDims, scales.data(), 233 zeroPoints.data(), quantizedDimension, storageTypeMin, 234 storageTypeMax)); 235 }, 236 "Gets an instance of UniformQuantizedPerAxisType in the same context as " 237 "the provided storage type.", 238 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 239 py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), 240 py::arg("quantized_dimension"), py::arg("storage_type_min"), 241 py::arg("storage_type_max")); 242 uniformQuantizedPerAxisType.def_property_readonly( 243 "scales", 244 [](MlirType type) { 245 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 246 std::vector<double> scales; 247 scales.reserve(nDim); 248 for (intptr_t i = 0; i < nDim; ++i) { 249 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); 250 scales.push_back(scale); 251 } 252 }, 253 "The scales designate the difference between the real values " 254 "corresponding to consecutive quantized values differing by 1. The ith " 255 "scale corresponds to the ith slice in the quantized_dimension."); 256 uniformQuantizedPerAxisType.def_property_readonly( 257 "zero_points", 258 [](MlirType type) { 259 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 260 std::vector<int64_t> zeroPoints; 261 zeroPoints.reserve(nDim); 262 for (intptr_t i = 0; i < nDim; ++i) { 263 int64_t zeroPoint = 264 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); 265 zeroPoints.push_back(zeroPoint); 266 } 267 }, 268 "the storage values corresponding to the real value 0 in the affine " 269 "equation. The ith zero point corresponds to the ith slice in the " 270 "quantized_dimension."); 271 uniformQuantizedPerAxisType.def_property_readonly( 272 "quantized_dimension", 273 [](MlirType type) { 274 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); 275 }, 276 "Specifies the dimension of the shape that the scales and zero points " 277 "correspond to."); 278 uniformQuantizedPerAxisType.def_property_readonly( 279 "is_fixed_point", 280 [](MlirType type) { 281 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); 282 }, 283 "Fixed point values are real numbers divided by a scale."); 284 285 //===-------------------------------------------------------------------===// 286 // CalibratedQuantizedType 287 //===-------------------------------------------------------------------===// 288 289 auto calibratedQuantizedType = mlir_type_subclass( 290 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, 291 quantizedType.get_class()); 292 calibratedQuantizedType.def_classmethod( 293 "get", 294 [](py::object cls, MlirType expressedType, double min, double max) { 295 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); 296 }, 297 "Gets an instance of CalibratedQuantizedType in the same context as the " 298 "provided expressed type.", 299 py::arg("cls"), py::arg("expressed_type"), py::arg("min"), 300 py::arg("max")); 301 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { 302 return mlirCalibratedQuantizedTypeGetMin(type); 303 }); 304 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { 305 return mlirCalibratedQuantizedTypeGetMax(type); 306 }); 307 } 308