//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
866d4090dSAlex Zinenko
966d4090dSAlex Zinenko #include "mlir-c/Dialect/Quant.h"
1066d4090dSAlex Zinenko #include "mlir-c/IR.h"
1166d4090dSAlex Zinenko #include "mlir/Bindings/Python/PybindAdaptors.h"
1266d4090dSAlex Zinenko
1366d4090dSAlex Zinenko namespace py = pybind11;
1466d4090dSAlex Zinenko using namespace llvm;
1566d4090dSAlex Zinenko using namespace mlir;
1666d4090dSAlex Zinenko using namespace mlir::python::adaptors;
1766d4090dSAlex Zinenko
populateDialectQuantSubmodule(const py::module & m)18*95ddbed9SAlex Zinenko static void populateDialectQuantSubmodule(const py::module &m) {
1966d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
2066d4090dSAlex Zinenko // QuantizedType
2166d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
2266d4090dSAlex Zinenko
23*95ddbed9SAlex Zinenko auto quantizedType =
24*95ddbed9SAlex Zinenko mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
2566d4090dSAlex Zinenko quantizedType.def_staticmethod(
2666d4090dSAlex Zinenko "default_minimum_for_integer",
2766d4090dSAlex Zinenko [](bool isSigned, unsigned integralWidth) {
2866d4090dSAlex Zinenko return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
2966d4090dSAlex Zinenko integralWidth);
3066d4090dSAlex Zinenko },
3166d4090dSAlex Zinenko "Default minimum value for the integer with the specified signedness and "
3266d4090dSAlex Zinenko "bit width.",
3366d4090dSAlex Zinenko py::arg("is_signed"), py::arg("integral_width"));
3466d4090dSAlex Zinenko quantizedType.def_staticmethod(
3566d4090dSAlex Zinenko "default_maximum_for_integer",
3666d4090dSAlex Zinenko [](bool isSigned, unsigned integralWidth) {
3766d4090dSAlex Zinenko return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
3866d4090dSAlex Zinenko integralWidth);
3966d4090dSAlex Zinenko },
4066d4090dSAlex Zinenko "Default maximum value for the integer with the specified signedness and "
4166d4090dSAlex Zinenko "bit width.",
4266d4090dSAlex Zinenko py::arg("is_signed"), py::arg("integral_width"));
4366d4090dSAlex Zinenko quantizedType.def_property_readonly(
4466d4090dSAlex Zinenko "expressed_type",
4566d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
4666d4090dSAlex Zinenko "Type expressed by this quantized type.");
4766d4090dSAlex Zinenko quantizedType.def_property_readonly(
4866d4090dSAlex Zinenko "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
4966d4090dSAlex Zinenko "Flags of this quantized type (named accessors should be preferred to "
5066d4090dSAlex Zinenko "this)");
5166d4090dSAlex Zinenko quantizedType.def_property_readonly(
5266d4090dSAlex Zinenko "is_signed",
5366d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
5466d4090dSAlex Zinenko "Signedness of this quantized type.");
5566d4090dSAlex Zinenko quantizedType.def_property_readonly(
5666d4090dSAlex Zinenko "storage_type",
5766d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
5866d4090dSAlex Zinenko "Storage type backing this quantized type.");
5966d4090dSAlex Zinenko quantizedType.def_property_readonly(
6066d4090dSAlex Zinenko "storage_type_min",
6166d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
6266d4090dSAlex Zinenko "The minimum value held by the storage type of this quantized type.");
6366d4090dSAlex Zinenko quantizedType.def_property_readonly(
6466d4090dSAlex Zinenko "storage_type_max",
6566d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
6666d4090dSAlex Zinenko "The maximum value held by the storage type of this quantized type.");
6766d4090dSAlex Zinenko quantizedType.def_property_readonly(
6866d4090dSAlex Zinenko "storage_type_integral_width",
6966d4090dSAlex Zinenko [](MlirType type) {
7066d4090dSAlex Zinenko return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
7166d4090dSAlex Zinenko },
7266d4090dSAlex Zinenko "The bitwidth of the storage type of this quantized type.");
7366d4090dSAlex Zinenko quantizedType.def(
7466d4090dSAlex Zinenko "is_compatible_expressed_type",
7566d4090dSAlex Zinenko [](MlirType type, MlirType candidate) {
7666d4090dSAlex Zinenko return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
7766d4090dSAlex Zinenko },
7866d4090dSAlex Zinenko "Checks whether the candidate type can be expressed by this quantized "
7966d4090dSAlex Zinenko "type.",
8066d4090dSAlex Zinenko py::arg("candidate"));
8166d4090dSAlex Zinenko quantizedType.def_property_readonly(
8266d4090dSAlex Zinenko "quantized_element_type",
8366d4090dSAlex Zinenko [](MlirType type) {
8466d4090dSAlex Zinenko return mlirQuantizedTypeGetQuantizedElementType(type);
8566d4090dSAlex Zinenko },
8666d4090dSAlex Zinenko "Element type of this quantized type expressed as quantized type.");
8766d4090dSAlex Zinenko quantizedType.def(
8866d4090dSAlex Zinenko "cast_from_storage_type",
8966d4090dSAlex Zinenko [](MlirType type, MlirType candidate) {
9066d4090dSAlex Zinenko MlirType castResult =
9166d4090dSAlex Zinenko mlirQuantizedTypeCastFromStorageType(type, candidate);
9266d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult))
9366d4090dSAlex Zinenko return castResult;
9466d4090dSAlex Zinenko throw py::type_error("Invalid cast.");
9566d4090dSAlex Zinenko },
9666d4090dSAlex Zinenko "Casts from a type based on the storage type of this quantized type to a "
9766d4090dSAlex Zinenko "corresponding type based on the quantized type. Raises TypeError if the "
9866d4090dSAlex Zinenko "cast is not valid.",
9966d4090dSAlex Zinenko py::arg("candidate"));
10066d4090dSAlex Zinenko quantizedType.def_staticmethod(
10166d4090dSAlex Zinenko "cast_to_storage_type",
10266d4090dSAlex Zinenko [](MlirType type) {
10366d4090dSAlex Zinenko MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
10466d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult))
10566d4090dSAlex Zinenko return castResult;
10666d4090dSAlex Zinenko throw py::type_error("Invalid cast.");
10766d4090dSAlex Zinenko },
10866d4090dSAlex Zinenko "Casts from a type based on a quantized type to a corresponding type "
10966d4090dSAlex Zinenko "based on the storage type of this quantized type. Raises TypeError if "
11066d4090dSAlex Zinenko "the cast is not valid.",
11166d4090dSAlex Zinenko py::arg("type"));
11266d4090dSAlex Zinenko quantizedType.def(
11366d4090dSAlex Zinenko "cast_from_expressed_type",
11466d4090dSAlex Zinenko [](MlirType type, MlirType candidate) {
11566d4090dSAlex Zinenko MlirType castResult =
11666d4090dSAlex Zinenko mlirQuantizedTypeCastFromExpressedType(type, candidate);
11766d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult))
11866d4090dSAlex Zinenko return castResult;
11966d4090dSAlex Zinenko throw py::type_error("Invalid cast.");
12066d4090dSAlex Zinenko },
12166d4090dSAlex Zinenko "Casts from a type based on the expressed type of this quantized type to "
12266d4090dSAlex Zinenko "a corresponding type based on the quantized type. Raises TypeError if "
12366d4090dSAlex Zinenko "the cast is not valid.",
12466d4090dSAlex Zinenko py::arg("candidate"));
12566d4090dSAlex Zinenko quantizedType.def_staticmethod(
12666d4090dSAlex Zinenko "cast_to_expressed_type",
12766d4090dSAlex Zinenko [](MlirType type) {
12866d4090dSAlex Zinenko MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
12966d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult))
13066d4090dSAlex Zinenko return castResult;
13166d4090dSAlex Zinenko throw py::type_error("Invalid cast.");
13266d4090dSAlex Zinenko },
13366d4090dSAlex Zinenko "Casts from a type based on a quantized type to a corresponding type "
13466d4090dSAlex Zinenko "based on the expressed type of this quantized type. Raises TypeError if "
13566d4090dSAlex Zinenko "the cast is not valid.",
13666d4090dSAlex Zinenko py::arg("type"));
13766d4090dSAlex Zinenko quantizedType.def(
13866d4090dSAlex Zinenko "cast_expressed_to_storage_type",
13966d4090dSAlex Zinenko [](MlirType type, MlirType candidate) {
14066d4090dSAlex Zinenko MlirType castResult =
14166d4090dSAlex Zinenko mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
14266d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult))
14366d4090dSAlex Zinenko return castResult;
14466d4090dSAlex Zinenko throw py::type_error("Invalid cast.");
14566d4090dSAlex Zinenko },
14666d4090dSAlex Zinenko "Casts from a type based on the expressed type of this quantized type to "
14766d4090dSAlex Zinenko "a corresponding type based on the storage type. Raises TypeError if the "
14866d4090dSAlex Zinenko "cast is not valid.",
14966d4090dSAlex Zinenko py::arg("candidate"));
15066d4090dSAlex Zinenko
15166d4090dSAlex Zinenko quantizedType.get_class().attr("FLAG_SIGNED") =
15266d4090dSAlex Zinenko mlirQuantizedTypeGetSignedFlag();
15366d4090dSAlex Zinenko
15466d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
15566d4090dSAlex Zinenko // AnyQuantizedType
15666d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
15766d4090dSAlex Zinenko
15866d4090dSAlex Zinenko auto anyQuantizedType =
15966d4090dSAlex Zinenko mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
16066d4090dSAlex Zinenko quantizedType.get_class());
16166d4090dSAlex Zinenko anyQuantizedType.def_classmethod(
16266d4090dSAlex Zinenko "get",
16366d4090dSAlex Zinenko [](py::object cls, unsigned flags, MlirType storageType,
16466d4090dSAlex Zinenko MlirType expressedType, int64_t storageTypeMin,
16566d4090dSAlex Zinenko int64_t storageTypeMax) {
16666d4090dSAlex Zinenko return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
16766d4090dSAlex Zinenko storageTypeMin, storageTypeMax));
16866d4090dSAlex Zinenko },
16966d4090dSAlex Zinenko "Gets an instance of AnyQuantizedType in the same context as the "
17066d4090dSAlex Zinenko "provided storage type.",
17166d4090dSAlex Zinenko py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
17266d4090dSAlex Zinenko py::arg("expressed_type"), py::arg("storage_type_min"),
17366d4090dSAlex Zinenko py::arg("storage_type_max"));
17466d4090dSAlex Zinenko
17566d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
17666d4090dSAlex Zinenko // UniformQuantizedType
17766d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
17866d4090dSAlex Zinenko
17966d4090dSAlex Zinenko auto uniformQuantizedType = mlir_type_subclass(
18066d4090dSAlex Zinenko m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
18166d4090dSAlex Zinenko quantizedType.get_class());
18266d4090dSAlex Zinenko uniformQuantizedType.def_classmethod(
18366d4090dSAlex Zinenko "get",
18466d4090dSAlex Zinenko [](py::object cls, unsigned flags, MlirType storageType,
18566d4090dSAlex Zinenko MlirType expressedType, double scale, int64_t zeroPoint,
18666d4090dSAlex Zinenko int64_t storageTypeMin, int64_t storageTypeMax) {
18766d4090dSAlex Zinenko return cls(mlirUniformQuantizedTypeGet(flags, storageType,
18866d4090dSAlex Zinenko expressedType, scale, zeroPoint,
18966d4090dSAlex Zinenko storageTypeMin, storageTypeMax));
19066d4090dSAlex Zinenko },
19166d4090dSAlex Zinenko "Gets an instance of UniformQuantizedType in the same context as the "
19266d4090dSAlex Zinenko "provided storage type.",
19366d4090dSAlex Zinenko py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
19466d4090dSAlex Zinenko py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
19566d4090dSAlex Zinenko py::arg("storage_type_min"), py::arg("storage_type_max"));
19666d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly(
19766d4090dSAlex Zinenko "scale",
19866d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
19966d4090dSAlex Zinenko "The scale designates the difference between the real values "
20066d4090dSAlex Zinenko "corresponding to consecutive quantized values differing by 1.");
20166d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly(
20266d4090dSAlex Zinenko "zero_point",
20366d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
20466d4090dSAlex Zinenko "The storage value corresponding to the real value 0 in the affine "
20566d4090dSAlex Zinenko "equation.");
20666d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly(
20766d4090dSAlex Zinenko "is_fixed_point",
20866d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
20966d4090dSAlex Zinenko "Fixed point values are real numbers divided by a scale.");
21066d4090dSAlex Zinenko
21166d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
21266d4090dSAlex Zinenko // UniformQuantizedPerAxisType
21366d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
21466d4090dSAlex Zinenko auto uniformQuantizedPerAxisType = mlir_type_subclass(
21566d4090dSAlex Zinenko m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
21666d4090dSAlex Zinenko quantizedType.get_class());
21766d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_classmethod(
21866d4090dSAlex Zinenko "get",
21966d4090dSAlex Zinenko [](py::object cls, unsigned flags, MlirType storageType,
22066d4090dSAlex Zinenko MlirType expressedType, std::vector<double> scales,
22166d4090dSAlex Zinenko std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
22266d4090dSAlex Zinenko int64_t storageTypeMin, int64_t storageTypeMax) {
22366d4090dSAlex Zinenko if (scales.size() != zeroPoints.size())
22466d4090dSAlex Zinenko throw py::value_error(
22566d4090dSAlex Zinenko "Mismatching number of scales and zero points.");
22666d4090dSAlex Zinenko auto nDims = static_cast<intptr_t>(scales.size());
22766d4090dSAlex Zinenko return cls(mlirUniformQuantizedPerAxisTypeGet(
22866d4090dSAlex Zinenko flags, storageType, expressedType, nDims, scales.data(),
22966d4090dSAlex Zinenko zeroPoints.data(), quantizedDimension, storageTypeMin,
23066d4090dSAlex Zinenko storageTypeMax));
23166d4090dSAlex Zinenko },
23266d4090dSAlex Zinenko "Gets an instance of UniformQuantizedPerAxisType in the same context as "
23366d4090dSAlex Zinenko "the provided storage type.",
23466d4090dSAlex Zinenko py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
23566d4090dSAlex Zinenko py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
23666d4090dSAlex Zinenko py::arg("quantized_dimension"), py::arg("storage_type_min"),
23766d4090dSAlex Zinenko py::arg("storage_type_max"));
23866d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly(
23966d4090dSAlex Zinenko "scales",
24066d4090dSAlex Zinenko [](MlirType type) {
24166d4090dSAlex Zinenko intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
24266d4090dSAlex Zinenko std::vector<double> scales;
24366d4090dSAlex Zinenko scales.reserve(nDim);
24466d4090dSAlex Zinenko for (intptr_t i = 0; i < nDim; ++i) {
24566d4090dSAlex Zinenko double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
24666d4090dSAlex Zinenko scales.push_back(scale);
24766d4090dSAlex Zinenko }
24866d4090dSAlex Zinenko },
24966d4090dSAlex Zinenko "The scales designate the difference between the real values "
25066d4090dSAlex Zinenko "corresponding to consecutive quantized values differing by 1. The ith "
25166d4090dSAlex Zinenko "scale corresponds to the ith slice in the quantized_dimension.");
25266d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly(
25366d4090dSAlex Zinenko "zero_points",
25466d4090dSAlex Zinenko [](MlirType type) {
25566d4090dSAlex Zinenko intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
25666d4090dSAlex Zinenko std::vector<int64_t> zeroPoints;
25766d4090dSAlex Zinenko zeroPoints.reserve(nDim);
25866d4090dSAlex Zinenko for (intptr_t i = 0; i < nDim; ++i) {
25966d4090dSAlex Zinenko int64_t zeroPoint =
26066d4090dSAlex Zinenko mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
26166d4090dSAlex Zinenko zeroPoints.push_back(zeroPoint);
26266d4090dSAlex Zinenko }
26366d4090dSAlex Zinenko },
26466d4090dSAlex Zinenko "the storage values corresponding to the real value 0 in the affine "
26566d4090dSAlex Zinenko "equation. The ith zero point corresponds to the ith slice in the "
26666d4090dSAlex Zinenko "quantized_dimension.");
26766d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly(
26866d4090dSAlex Zinenko "quantized_dimension",
26966d4090dSAlex Zinenko [](MlirType type) {
27066d4090dSAlex Zinenko return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
27166d4090dSAlex Zinenko },
27266d4090dSAlex Zinenko "Specifies the dimension of the shape that the scales and zero points "
27366d4090dSAlex Zinenko "correspond to.");
27466d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly(
27566d4090dSAlex Zinenko "is_fixed_point",
27666d4090dSAlex Zinenko [](MlirType type) {
27766d4090dSAlex Zinenko return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
27866d4090dSAlex Zinenko },
27966d4090dSAlex Zinenko "Fixed point values are real numbers divided by a scale.");
28066d4090dSAlex Zinenko
28166d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
28266d4090dSAlex Zinenko // CalibratedQuantizedType
28366d4090dSAlex Zinenko //===-------------------------------------------------------------------===//
28466d4090dSAlex Zinenko
28566d4090dSAlex Zinenko auto calibratedQuantizedType = mlir_type_subclass(
28666d4090dSAlex Zinenko m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
28766d4090dSAlex Zinenko quantizedType.get_class());
28866d4090dSAlex Zinenko calibratedQuantizedType.def_classmethod(
28966d4090dSAlex Zinenko "get",
29066d4090dSAlex Zinenko [](py::object cls, MlirType expressedType, double min, double max) {
29166d4090dSAlex Zinenko return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
29266d4090dSAlex Zinenko },
29366d4090dSAlex Zinenko "Gets an instance of CalibratedQuantizedType in the same context as the "
29466d4090dSAlex Zinenko "provided expressed type.",
29566d4090dSAlex Zinenko py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
29666d4090dSAlex Zinenko py::arg("max"));
29766d4090dSAlex Zinenko calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
29866d4090dSAlex Zinenko return mlirCalibratedQuantizedTypeGetMin(type);
29966d4090dSAlex Zinenko });
30066d4090dSAlex Zinenko calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
30166d4090dSAlex Zinenko return mlirCalibratedQuantizedTypeGetMax(type);
30266d4090dSAlex Zinenko });
30366d4090dSAlex Zinenko }
304*95ddbed9SAlex Zinenko
// Native extension entry point for `_mlirDialectsQuant`: attaches the module
// docstring, then delegates all type registrations to
// populateDialectQuantSubmodule.
PYBIND11_MODULE(_mlirDialectsQuant, mod) {
  mod.attr("__doc__") = "MLIR Quantization dialect";
  populateDialectQuantSubmodule(mod);
}
310