166d4090dSAlex Zinenko //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
266d4090dSAlex Zinenko //
366d4090dSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
466d4090dSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
566d4090dSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
666d4090dSAlex Zinenko //
766d4090dSAlex Zinenko //===----------------------------------------------------------------------===//
866d4090dSAlex Zinenko 
966d4090dSAlex Zinenko #include "mlir-c/Dialect/Quant.h"
1066d4090dSAlex Zinenko #include "mlir-c/IR.h"
1166d4090dSAlex Zinenko #include "mlir/Bindings/Python/PybindAdaptors.h"
1266d4090dSAlex Zinenko 
1366d4090dSAlex Zinenko namespace py = pybind11;
1466d4090dSAlex Zinenko using namespace llvm;
1566d4090dSAlex Zinenko using namespace mlir;
1666d4090dSAlex Zinenko using namespace mlir::python::adaptors;
1766d4090dSAlex Zinenko 
populateDialectQuantSubmodule(const py::module & m)18*95ddbed9SAlex Zinenko static void populateDialectQuantSubmodule(const py::module &m) {
1966d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
2066d4090dSAlex Zinenko   // QuantizedType
2166d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
2266d4090dSAlex Zinenko 
23*95ddbed9SAlex Zinenko   auto quantizedType =
24*95ddbed9SAlex Zinenko       mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
2566d4090dSAlex Zinenko   quantizedType.def_staticmethod(
2666d4090dSAlex Zinenko       "default_minimum_for_integer",
2766d4090dSAlex Zinenko       [](bool isSigned, unsigned integralWidth) {
2866d4090dSAlex Zinenko         return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
2966d4090dSAlex Zinenko                                                             integralWidth);
3066d4090dSAlex Zinenko       },
3166d4090dSAlex Zinenko       "Default minimum value for the integer with the specified signedness and "
3266d4090dSAlex Zinenko       "bit width.",
3366d4090dSAlex Zinenko       py::arg("is_signed"), py::arg("integral_width"));
3466d4090dSAlex Zinenko   quantizedType.def_staticmethod(
3566d4090dSAlex Zinenko       "default_maximum_for_integer",
3666d4090dSAlex Zinenko       [](bool isSigned, unsigned integralWidth) {
3766d4090dSAlex Zinenko         return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
3866d4090dSAlex Zinenko                                                             integralWidth);
3966d4090dSAlex Zinenko       },
4066d4090dSAlex Zinenko       "Default maximum value for the integer with the specified signedness and "
4166d4090dSAlex Zinenko       "bit width.",
4266d4090dSAlex Zinenko       py::arg("is_signed"), py::arg("integral_width"));
4366d4090dSAlex Zinenko   quantizedType.def_property_readonly(
4466d4090dSAlex Zinenko       "expressed_type",
4566d4090dSAlex Zinenko       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
4666d4090dSAlex Zinenko       "Type expressed by this quantized type.");
4766d4090dSAlex Zinenko   quantizedType.def_property_readonly(
4866d4090dSAlex Zinenko       "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
4966d4090dSAlex Zinenko       "Flags of this quantized type (named accessors should be preferred to "
5066d4090dSAlex Zinenko       "this)");
5166d4090dSAlex Zinenko   quantizedType.def_property_readonly(
5266d4090dSAlex Zinenko       "is_signed",
5366d4090dSAlex Zinenko       [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
5466d4090dSAlex Zinenko       "Signedness of this quantized type.");
5566d4090dSAlex Zinenko   quantizedType.def_property_readonly(
5666d4090dSAlex Zinenko       "storage_type",
5766d4090dSAlex Zinenko       [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
5866d4090dSAlex Zinenko       "Storage type backing this quantized type.");
5966d4090dSAlex Zinenko   quantizedType.def_property_readonly(
6066d4090dSAlex Zinenko       "storage_type_min",
6166d4090dSAlex Zinenko       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
6266d4090dSAlex Zinenko       "The minimum value held by the storage type of this quantized type.");
6366d4090dSAlex Zinenko   quantizedType.def_property_readonly(
6466d4090dSAlex Zinenko       "storage_type_max",
6566d4090dSAlex Zinenko       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
6666d4090dSAlex Zinenko       "The maximum value held by the storage type of this quantized type.");
6766d4090dSAlex Zinenko   quantizedType.def_property_readonly(
6866d4090dSAlex Zinenko       "storage_type_integral_width",
6966d4090dSAlex Zinenko       [](MlirType type) {
7066d4090dSAlex Zinenko         return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
7166d4090dSAlex Zinenko       },
7266d4090dSAlex Zinenko       "The bitwidth of the storage type of this quantized type.");
7366d4090dSAlex Zinenko   quantizedType.def(
7466d4090dSAlex Zinenko       "is_compatible_expressed_type",
7566d4090dSAlex Zinenko       [](MlirType type, MlirType candidate) {
7666d4090dSAlex Zinenko         return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
7766d4090dSAlex Zinenko       },
7866d4090dSAlex Zinenko       "Checks whether the candidate type can be expressed by this quantized "
7966d4090dSAlex Zinenko       "type.",
8066d4090dSAlex Zinenko       py::arg("candidate"));
8166d4090dSAlex Zinenko   quantizedType.def_property_readonly(
8266d4090dSAlex Zinenko       "quantized_element_type",
8366d4090dSAlex Zinenko       [](MlirType type) {
8466d4090dSAlex Zinenko         return mlirQuantizedTypeGetQuantizedElementType(type);
8566d4090dSAlex Zinenko       },
8666d4090dSAlex Zinenko       "Element type of this quantized type expressed as quantized type.");
8766d4090dSAlex Zinenko   quantizedType.def(
8866d4090dSAlex Zinenko       "cast_from_storage_type",
8966d4090dSAlex Zinenko       [](MlirType type, MlirType candidate) {
9066d4090dSAlex Zinenko         MlirType castResult =
9166d4090dSAlex Zinenko             mlirQuantizedTypeCastFromStorageType(type, candidate);
9266d4090dSAlex Zinenko         if (!mlirTypeIsNull(castResult))
9366d4090dSAlex Zinenko           return castResult;
9466d4090dSAlex Zinenko         throw py::type_error("Invalid cast.");
9566d4090dSAlex Zinenko       },
9666d4090dSAlex Zinenko       "Casts from a type based on the storage type of this quantized type to a "
9766d4090dSAlex Zinenko       "corresponding type based on the quantized type. Raises TypeError if the "
9866d4090dSAlex Zinenko       "cast is not valid.",
9966d4090dSAlex Zinenko       py::arg("candidate"));
10066d4090dSAlex Zinenko   quantizedType.def_staticmethod(
10166d4090dSAlex Zinenko       "cast_to_storage_type",
10266d4090dSAlex Zinenko       [](MlirType type) {
10366d4090dSAlex Zinenko         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
10466d4090dSAlex Zinenko         if (!mlirTypeIsNull(castResult))
10566d4090dSAlex Zinenko           return castResult;
10666d4090dSAlex Zinenko         throw py::type_error("Invalid cast.");
10766d4090dSAlex Zinenko       },
10866d4090dSAlex Zinenko       "Casts from a type based on a quantized type to a corresponding type "
10966d4090dSAlex Zinenko       "based on the storage type of this quantized type. Raises TypeError if "
11066d4090dSAlex Zinenko       "the cast is not valid.",
11166d4090dSAlex Zinenko       py::arg("type"));
11266d4090dSAlex Zinenko   quantizedType.def(
11366d4090dSAlex Zinenko       "cast_from_expressed_type",
11466d4090dSAlex Zinenko       [](MlirType type, MlirType candidate) {
11566d4090dSAlex Zinenko         MlirType castResult =
11666d4090dSAlex Zinenko             mlirQuantizedTypeCastFromExpressedType(type, candidate);
11766d4090dSAlex Zinenko         if (!mlirTypeIsNull(castResult))
11866d4090dSAlex Zinenko           return castResult;
11966d4090dSAlex Zinenko         throw py::type_error("Invalid cast.");
12066d4090dSAlex Zinenko       },
12166d4090dSAlex Zinenko       "Casts from a type based on the expressed type of this quantized type to "
12266d4090dSAlex Zinenko       "a corresponding type based on the quantized type. Raises TypeError if "
12366d4090dSAlex Zinenko       "the cast is not valid.",
12466d4090dSAlex Zinenko       py::arg("candidate"));
12566d4090dSAlex Zinenko   quantizedType.def_staticmethod(
12666d4090dSAlex Zinenko       "cast_to_expressed_type",
12766d4090dSAlex Zinenko       [](MlirType type) {
12866d4090dSAlex Zinenko         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
12966d4090dSAlex Zinenko         if (!mlirTypeIsNull(castResult))
13066d4090dSAlex Zinenko           return castResult;
13166d4090dSAlex Zinenko         throw py::type_error("Invalid cast.");
13266d4090dSAlex Zinenko       },
13366d4090dSAlex Zinenko       "Casts from a type based on a quantized type to a corresponding type "
13466d4090dSAlex Zinenko       "based on the expressed type of this quantized type. Raises TypeError if "
13566d4090dSAlex Zinenko       "the cast is not valid.",
13666d4090dSAlex Zinenko       py::arg("type"));
13766d4090dSAlex Zinenko   quantizedType.def(
13866d4090dSAlex Zinenko       "cast_expressed_to_storage_type",
13966d4090dSAlex Zinenko       [](MlirType type, MlirType candidate) {
14066d4090dSAlex Zinenko         MlirType castResult =
14166d4090dSAlex Zinenko             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
14266d4090dSAlex Zinenko         if (!mlirTypeIsNull(castResult))
14366d4090dSAlex Zinenko           return castResult;
14466d4090dSAlex Zinenko         throw py::type_error("Invalid cast.");
14566d4090dSAlex Zinenko       },
14666d4090dSAlex Zinenko       "Casts from a type based on the expressed type of this quantized type to "
14766d4090dSAlex Zinenko       "a corresponding type based on the storage type. Raises TypeError if the "
14866d4090dSAlex Zinenko       "cast is not valid.",
14966d4090dSAlex Zinenko       py::arg("candidate"));
15066d4090dSAlex Zinenko 
15166d4090dSAlex Zinenko   quantizedType.get_class().attr("FLAG_SIGNED") =
15266d4090dSAlex Zinenko       mlirQuantizedTypeGetSignedFlag();
15366d4090dSAlex Zinenko 
15466d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
15566d4090dSAlex Zinenko   // AnyQuantizedType
15666d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
15766d4090dSAlex Zinenko 
15866d4090dSAlex Zinenko   auto anyQuantizedType =
15966d4090dSAlex Zinenko       mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
16066d4090dSAlex Zinenko                          quantizedType.get_class());
16166d4090dSAlex Zinenko   anyQuantizedType.def_classmethod(
16266d4090dSAlex Zinenko       "get",
16366d4090dSAlex Zinenko       [](py::object cls, unsigned flags, MlirType storageType,
16466d4090dSAlex Zinenko          MlirType expressedType, int64_t storageTypeMin,
16566d4090dSAlex Zinenko          int64_t storageTypeMax) {
16666d4090dSAlex Zinenko         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
16766d4090dSAlex Zinenko                                            storageTypeMin, storageTypeMax));
16866d4090dSAlex Zinenko       },
16966d4090dSAlex Zinenko       "Gets an instance of AnyQuantizedType in the same context as the "
17066d4090dSAlex Zinenko       "provided storage type.",
17166d4090dSAlex Zinenko       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
17266d4090dSAlex Zinenko       py::arg("expressed_type"), py::arg("storage_type_min"),
17366d4090dSAlex Zinenko       py::arg("storage_type_max"));
17466d4090dSAlex Zinenko 
17566d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
17666d4090dSAlex Zinenko   // UniformQuantizedType
17766d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
17866d4090dSAlex Zinenko 
17966d4090dSAlex Zinenko   auto uniformQuantizedType = mlir_type_subclass(
18066d4090dSAlex Zinenko       m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
18166d4090dSAlex Zinenko       quantizedType.get_class());
18266d4090dSAlex Zinenko   uniformQuantizedType.def_classmethod(
18366d4090dSAlex Zinenko       "get",
18466d4090dSAlex Zinenko       [](py::object cls, unsigned flags, MlirType storageType,
18566d4090dSAlex Zinenko          MlirType expressedType, double scale, int64_t zeroPoint,
18666d4090dSAlex Zinenko          int64_t storageTypeMin, int64_t storageTypeMax) {
18766d4090dSAlex Zinenko         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
18866d4090dSAlex Zinenko                                                expressedType, scale, zeroPoint,
18966d4090dSAlex Zinenko                                                storageTypeMin, storageTypeMax));
19066d4090dSAlex Zinenko       },
19166d4090dSAlex Zinenko       "Gets an instance of UniformQuantizedType in the same context as the "
19266d4090dSAlex Zinenko       "provided storage type.",
19366d4090dSAlex Zinenko       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
19466d4090dSAlex Zinenko       py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
19566d4090dSAlex Zinenko       py::arg("storage_type_min"), py::arg("storage_type_max"));
19666d4090dSAlex Zinenko   uniformQuantizedType.def_property_readonly(
19766d4090dSAlex Zinenko       "scale",
19866d4090dSAlex Zinenko       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
19966d4090dSAlex Zinenko       "The scale designates the difference between the real values "
20066d4090dSAlex Zinenko       "corresponding to consecutive quantized values differing by 1.");
20166d4090dSAlex Zinenko   uniformQuantizedType.def_property_readonly(
20266d4090dSAlex Zinenko       "zero_point",
20366d4090dSAlex Zinenko       [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
20466d4090dSAlex Zinenko       "The storage value corresponding to the real value 0 in the affine "
20566d4090dSAlex Zinenko       "equation.");
20666d4090dSAlex Zinenko   uniformQuantizedType.def_property_readonly(
20766d4090dSAlex Zinenko       "is_fixed_point",
20866d4090dSAlex Zinenko       [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
20966d4090dSAlex Zinenko       "Fixed point values are real numbers divided by a scale.");
21066d4090dSAlex Zinenko 
21166d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
21266d4090dSAlex Zinenko   // UniformQuantizedPerAxisType
21366d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
21466d4090dSAlex Zinenko   auto uniformQuantizedPerAxisType = mlir_type_subclass(
21566d4090dSAlex Zinenko       m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
21666d4090dSAlex Zinenko       quantizedType.get_class());
21766d4090dSAlex Zinenko   uniformQuantizedPerAxisType.def_classmethod(
21866d4090dSAlex Zinenko       "get",
21966d4090dSAlex Zinenko       [](py::object cls, unsigned flags, MlirType storageType,
22066d4090dSAlex Zinenko          MlirType expressedType, std::vector<double> scales,
22166d4090dSAlex Zinenko          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
22266d4090dSAlex Zinenko          int64_t storageTypeMin, int64_t storageTypeMax) {
22366d4090dSAlex Zinenko         if (scales.size() != zeroPoints.size())
22466d4090dSAlex Zinenko           throw py::value_error(
22566d4090dSAlex Zinenko               "Mismatching number of scales and zero points.");
22666d4090dSAlex Zinenko         auto nDims = static_cast<intptr_t>(scales.size());
22766d4090dSAlex Zinenko         return cls(mlirUniformQuantizedPerAxisTypeGet(
22866d4090dSAlex Zinenko             flags, storageType, expressedType, nDims, scales.data(),
22966d4090dSAlex Zinenko             zeroPoints.data(), quantizedDimension, storageTypeMin,
23066d4090dSAlex Zinenko             storageTypeMax));
23166d4090dSAlex Zinenko       },
23266d4090dSAlex Zinenko       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
23366d4090dSAlex Zinenko       "the provided storage type.",
23466d4090dSAlex Zinenko       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
23566d4090dSAlex Zinenko       py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
23666d4090dSAlex Zinenko       py::arg("quantized_dimension"), py::arg("storage_type_min"),
23766d4090dSAlex Zinenko       py::arg("storage_type_max"));
23866d4090dSAlex Zinenko   uniformQuantizedPerAxisType.def_property_readonly(
23966d4090dSAlex Zinenko       "scales",
24066d4090dSAlex Zinenko       [](MlirType type) {
24166d4090dSAlex Zinenko         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
24266d4090dSAlex Zinenko         std::vector<double> scales;
24366d4090dSAlex Zinenko         scales.reserve(nDim);
24466d4090dSAlex Zinenko         for (intptr_t i = 0; i < nDim; ++i) {
24566d4090dSAlex Zinenko           double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
24666d4090dSAlex Zinenko           scales.push_back(scale);
24766d4090dSAlex Zinenko         }
24866d4090dSAlex Zinenko       },
24966d4090dSAlex Zinenko       "The scales designate the difference between the real values "
25066d4090dSAlex Zinenko       "corresponding to consecutive quantized values differing by 1. The ith "
25166d4090dSAlex Zinenko       "scale corresponds to the ith slice in the quantized_dimension.");
25266d4090dSAlex Zinenko   uniformQuantizedPerAxisType.def_property_readonly(
25366d4090dSAlex Zinenko       "zero_points",
25466d4090dSAlex Zinenko       [](MlirType type) {
25566d4090dSAlex Zinenko         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
25666d4090dSAlex Zinenko         std::vector<int64_t> zeroPoints;
25766d4090dSAlex Zinenko         zeroPoints.reserve(nDim);
25866d4090dSAlex Zinenko         for (intptr_t i = 0; i < nDim; ++i) {
25966d4090dSAlex Zinenko           int64_t zeroPoint =
26066d4090dSAlex Zinenko               mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
26166d4090dSAlex Zinenko           zeroPoints.push_back(zeroPoint);
26266d4090dSAlex Zinenko         }
26366d4090dSAlex Zinenko       },
26466d4090dSAlex Zinenko       "the storage values corresponding to the real value 0 in the affine "
26566d4090dSAlex Zinenko       "equation. The ith zero point corresponds to the ith slice in the "
26666d4090dSAlex Zinenko       "quantized_dimension.");
26766d4090dSAlex Zinenko   uniformQuantizedPerAxisType.def_property_readonly(
26866d4090dSAlex Zinenko       "quantized_dimension",
26966d4090dSAlex Zinenko       [](MlirType type) {
27066d4090dSAlex Zinenko         return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
27166d4090dSAlex Zinenko       },
27266d4090dSAlex Zinenko       "Specifies the dimension of the shape that the scales and zero points "
27366d4090dSAlex Zinenko       "correspond to.");
27466d4090dSAlex Zinenko   uniformQuantizedPerAxisType.def_property_readonly(
27566d4090dSAlex Zinenko       "is_fixed_point",
27666d4090dSAlex Zinenko       [](MlirType type) {
27766d4090dSAlex Zinenko         return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
27866d4090dSAlex Zinenko       },
27966d4090dSAlex Zinenko       "Fixed point values are real numbers divided by a scale.");
28066d4090dSAlex Zinenko 
28166d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
28266d4090dSAlex Zinenko   // CalibratedQuantizedType
28366d4090dSAlex Zinenko   //===-------------------------------------------------------------------===//
28466d4090dSAlex Zinenko 
28566d4090dSAlex Zinenko   auto calibratedQuantizedType = mlir_type_subclass(
28666d4090dSAlex Zinenko       m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
28766d4090dSAlex Zinenko       quantizedType.get_class());
28866d4090dSAlex Zinenko   calibratedQuantizedType.def_classmethod(
28966d4090dSAlex Zinenko       "get",
29066d4090dSAlex Zinenko       [](py::object cls, MlirType expressedType, double min, double max) {
29166d4090dSAlex Zinenko         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
29266d4090dSAlex Zinenko       },
29366d4090dSAlex Zinenko       "Gets an instance of CalibratedQuantizedType in the same context as the "
29466d4090dSAlex Zinenko       "provided expressed type.",
29566d4090dSAlex Zinenko       py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
29666d4090dSAlex Zinenko       py::arg("max"));
29766d4090dSAlex Zinenko   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
29866d4090dSAlex Zinenko     return mlirCalibratedQuantizedTypeGetMin(type);
29966d4090dSAlex Zinenko   });
30066d4090dSAlex Zinenko   calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
30166d4090dSAlex Zinenko     return mlirCalibratedQuantizedTypeGetMax(type);
30266d4090dSAlex Zinenko   });
30366d4090dSAlex Zinenko }
304*95ddbed9SAlex Zinenko 
// Entry point of the `_mlirDialectsQuant` Python extension module. It sets
// the module docstring and then registers all quant dialect type bindings on
// the module.
PYBIND11_MODULE(_mlirDialectsQuant, quantModule) {
  quantModule.attr("__doc__") = "MLIR Quantization dialect";
  populateDialectQuantSubmodule(quantModule);
}
310