//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
89bcf13bfSAlex Zinenko 
99bcf13bfSAlex Zinenko #include "mlir-c/Dialect/Quant.h"
109bcf13bfSAlex Zinenko #include "mlir/CAPI/Registration.h"
119bcf13bfSAlex Zinenko #include "mlir/Dialect/Quant/QuantOps.h"
129bcf13bfSAlex Zinenko #include "mlir/Dialect/Quant/QuantTypes.h"
139bcf13bfSAlex Zinenko 
149bcf13bfSAlex Zinenko using namespace mlir;
159bcf13bfSAlex Zinenko 
// Registers quant::QuantizationDialect with the C API under the name `quant`,
// using the shared macro from mlir/CAPI/Registration.h.
// NOTE(review): the exact entry points generated depend on that macro.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
179bcf13bfSAlex Zinenko 
//===---------------------------------------------------------------------===//
// QuantizedType
//===---------------------------------------------------------------------===//
219bcf13bfSAlex Zinenko 
229bcf13bfSAlex Zinenko bool mlirTypeIsAQuantizedType(MlirType type) {
239bcf13bfSAlex Zinenko   return unwrap(type).isa<quant::QuantizedType>();
249bcf13bfSAlex Zinenko }
259bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetSignedFlag()269bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetSignedFlag() {
279bcf13bfSAlex Zinenko   return quant::QuantizationFlags::Signed;
289bcf13bfSAlex Zinenko }
299bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)309bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
319bcf13bfSAlex Zinenko                                                      unsigned integralWidth) {
329bcf13bfSAlex Zinenko   return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
339bcf13bfSAlex Zinenko                                                            integralWidth);
349bcf13bfSAlex Zinenko }
359bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)369bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
379bcf13bfSAlex Zinenko                                                      unsigned integralWidth) {
389bcf13bfSAlex Zinenko   return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
399bcf13bfSAlex Zinenko                                                            integralWidth);
409bcf13bfSAlex Zinenko }
419bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetExpressedType(MlirType type)429bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
439bcf13bfSAlex Zinenko   return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
449bcf13bfSAlex Zinenko }
459bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetFlags(MlirType type)469bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetFlags(MlirType type) {
479bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::QuantizedType>().getFlags();
489bcf13bfSAlex Zinenko }
499bcf13bfSAlex Zinenko 
mlirQuantizedTypeIsSigned(MlirType type)509bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsSigned(MlirType type) {
519bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::QuantizedType>().isSigned();
529bcf13bfSAlex Zinenko }
539bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetStorageType(MlirType type)549bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
559bcf13bfSAlex Zinenko   return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
569bcf13bfSAlex Zinenko }
579bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetStorageTypeMin(MlirType type)589bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
599bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
609bcf13bfSAlex Zinenko }
619bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetStorageTypeMax(MlirType type)629bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
639bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
649bcf13bfSAlex Zinenko }
659bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type)669bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
679bcf13bfSAlex Zinenko   return unwrap(type)
689bcf13bfSAlex Zinenko       .cast<quant::QuantizedType>()
699bcf13bfSAlex Zinenko       .getStorageTypeIntegralWidth();
709bcf13bfSAlex Zinenko }
719bcf13bfSAlex Zinenko 
mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,MlirType candidate)729bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
739bcf13bfSAlex Zinenko                                                 MlirType candidate) {
749bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
759bcf13bfSAlex Zinenko       unwrap(candidate));
769bcf13bfSAlex Zinenko }
779bcf13bfSAlex Zinenko 
mlirQuantizedTypeGetQuantizedElementType(MlirType type)789bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
799bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
809bcf13bfSAlex Zinenko }
819bcf13bfSAlex Zinenko 
mlirQuantizedTypeCastFromStorageType(MlirType type,MlirType candidate)829bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
839bcf13bfSAlex Zinenko                                               MlirType candidate) {
849bcf13bfSAlex Zinenko   return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
859bcf13bfSAlex Zinenko       unwrap(candidate)));
869bcf13bfSAlex Zinenko }
879bcf13bfSAlex Zinenko 
mlirQuantizedTypeCastToStorageType(MlirType type)889bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
899bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::castToStorageType(
909bcf13bfSAlex Zinenko       unwrap(type).cast<quant::QuantizedType>()));
919bcf13bfSAlex Zinenko }
929bcf13bfSAlex Zinenko 
mlirQuantizedTypeCastFromExpressedType(MlirType type,MlirType candidate)939bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
949bcf13bfSAlex Zinenko                                                 MlirType candidate) {
959bcf13bfSAlex Zinenko   return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
969bcf13bfSAlex Zinenko       unwrap(candidate)));
979bcf13bfSAlex Zinenko }
989bcf13bfSAlex Zinenko 
mlirQuantizedTypeCastToExpressedType(MlirType type)999bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
1009bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
1019bcf13bfSAlex Zinenko }
1029bcf13bfSAlex Zinenko 
mlirQuantizedTypeCastExpressedToStorageType(MlirType type,MlirType candidate)1039bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
1049bcf13bfSAlex Zinenko                                                      MlirType candidate) {
1059bcf13bfSAlex Zinenko   return wrap(
1069bcf13bfSAlex Zinenko       unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
1079bcf13bfSAlex Zinenko           unwrap(candidate)));
1089bcf13bfSAlex Zinenko }
1099bcf13bfSAlex Zinenko 
//===---------------------------------------------------------------------===//
// AnyQuantizedType
//===---------------------------------------------------------------------===//
1139bcf13bfSAlex Zinenko 
mlirTypeIsAAnyQuantizedType(MlirType type)1149bcf13bfSAlex Zinenko bool mlirTypeIsAAnyQuantizedType(MlirType type) {
1159bcf13bfSAlex Zinenko   return unwrap(type).isa<quant::AnyQuantizedType>();
1169bcf13bfSAlex Zinenko }
1179bcf13bfSAlex Zinenko 
mlirAnyQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,int64_t storageTypeMin,int64_t storageTypeMax)1189bcf13bfSAlex Zinenko MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
1199bcf13bfSAlex Zinenko                                  MlirType expressedType, int64_t storageTypeMin,
1209bcf13bfSAlex Zinenko                                  int64_t storageTypeMax) {
1219bcf13bfSAlex Zinenko   return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
1229bcf13bfSAlex Zinenko                                            unwrap(expressedType),
1239bcf13bfSAlex Zinenko                                            storageTypeMin, storageTypeMax));
1249bcf13bfSAlex Zinenko }
1259bcf13bfSAlex Zinenko 
//===---------------------------------------------------------------------===//
// UniformQuantizedType
//===---------------------------------------------------------------------===//
1299bcf13bfSAlex Zinenko 
mlirTypeIsAUniformQuantizedType(MlirType type)1309bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedType(MlirType type) {
1319bcf13bfSAlex Zinenko   return unwrap(type).isa<quant::UniformQuantizedType>();
1329bcf13bfSAlex Zinenko }
1339bcf13bfSAlex Zinenko 
mlirUniformQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)1349bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
1359bcf13bfSAlex Zinenko                                      MlirType expressedType, double scale,
1369bcf13bfSAlex Zinenko                                      int64_t zeroPoint, int64_t storageTypeMin,
1379bcf13bfSAlex Zinenko                                      int64_t storageTypeMax) {
1389bcf13bfSAlex Zinenko   return wrap(quant::UniformQuantizedType::get(
1399bcf13bfSAlex Zinenko       flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
1409bcf13bfSAlex Zinenko       storageTypeMin, storageTypeMax));
1419bcf13bfSAlex Zinenko }
1429bcf13bfSAlex Zinenko 
mlirUniformQuantizedTypeGetScale(MlirType type)1439bcf13bfSAlex Zinenko double mlirUniformQuantizedTypeGetScale(MlirType type) {
1449bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
1459bcf13bfSAlex Zinenko }
1469bcf13bfSAlex Zinenko 
mlirUniformQuantizedTypeGetZeroPoint(MlirType type)1479bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
1489bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
1499bcf13bfSAlex Zinenko }
1509bcf13bfSAlex Zinenko 
mlirUniformQuantizedTypeIsFixedPoint(MlirType type)1519bcf13bfSAlex Zinenko bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
1529bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
1539bcf13bfSAlex Zinenko }
1549bcf13bfSAlex Zinenko 
//===---------------------------------------------------------------------===//
// UniformQuantizedPerAxisType
//===---------------------------------------------------------------------===//
1589bcf13bfSAlex Zinenko 
mlirTypeIsAUniformQuantizedPerAxisType(MlirType type)1599bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
1609bcf13bfSAlex Zinenko   return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
1619bcf13bfSAlex Zinenko }
1629bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,intptr_t nDims,double * scales,int64_t * zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)1639bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedPerAxisTypeGet(
1649bcf13bfSAlex Zinenko     unsigned flags, MlirType storageType, MlirType expressedType,
1659bcf13bfSAlex Zinenko     intptr_t nDims, double *scales, int64_t *zeroPoints,
1669bcf13bfSAlex Zinenko     int32_t quantizedDimension, int64_t storageTypeMin,
1679bcf13bfSAlex Zinenko     int64_t storageTypeMax) {
1689bcf13bfSAlex Zinenko   return wrap(quant::UniformQuantizedPerAxisType::get(
1699bcf13bfSAlex Zinenko       flags, unwrap(storageType), unwrap(expressedType),
1709bcf13bfSAlex Zinenko       llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims),
1719bcf13bfSAlex Zinenko       quantizedDimension, storageTypeMin, storageTypeMax));
1729bcf13bfSAlex Zinenko }
1739bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type)1749bcf13bfSAlex Zinenko intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
1759bcf13bfSAlex Zinenko   return unwrap(type)
1769bcf13bfSAlex Zinenko       .cast<quant::UniformQuantizedPerAxisType>()
1779bcf13bfSAlex Zinenko       .getScales()
1789bcf13bfSAlex Zinenko       .size();
1799bcf13bfSAlex Zinenko }
1809bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeGetScale(MlirType type,intptr_t pos)1819bcf13bfSAlex Zinenko double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
1829bcf13bfSAlex Zinenko   return unwrap(type)
1839bcf13bfSAlex Zinenko       .cast<quant::UniformQuantizedPerAxisType>()
1849bcf13bfSAlex Zinenko       .getScales()[pos];
1859bcf13bfSAlex Zinenko }
1869bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,intptr_t pos)1879bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
1889bcf13bfSAlex Zinenko                                                     intptr_t pos) {
1899bcf13bfSAlex Zinenko   return unwrap(type)
1909bcf13bfSAlex Zinenko       .cast<quant::UniformQuantizedPerAxisType>()
1919bcf13bfSAlex Zinenko       .getZeroPoints()[pos];
1929bcf13bfSAlex Zinenko }
1939bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type)1949bcf13bfSAlex Zinenko int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
1959bcf13bfSAlex Zinenko   return unwrap(type)
1969bcf13bfSAlex Zinenko       .cast<quant::UniformQuantizedPerAxisType>()
1979bcf13bfSAlex Zinenko       .getQuantizedDimension();
1989bcf13bfSAlex Zinenko }
1999bcf13bfSAlex Zinenko 
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type)2009bcf13bfSAlex Zinenko bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
2019bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
2029bcf13bfSAlex Zinenko }
2039bcf13bfSAlex Zinenko 
//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
2079bcf13bfSAlex Zinenko 
mlirTypeIsACalibratedQuantizedType(MlirType type)2089bcf13bfSAlex Zinenko bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
2099bcf13bfSAlex Zinenko   return unwrap(type).isa<quant::CalibratedQuantizedType>();
2109bcf13bfSAlex Zinenko }
2119bcf13bfSAlex Zinenko 
mlirCalibratedQuantizedTypeGet(MlirType expressedType,double min,double max)2129bcf13bfSAlex Zinenko MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
2139bcf13bfSAlex Zinenko                                         double max) {
2149bcf13bfSAlex Zinenko   return wrap(
2159bcf13bfSAlex Zinenko       quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
2169bcf13bfSAlex Zinenko }
2179bcf13bfSAlex Zinenko 
mlirCalibratedQuantizedTypeGetMin(MlirType type)2189bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
2199bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
2209bcf13bfSAlex Zinenko }
2219bcf13bfSAlex Zinenko 
mlirCalibratedQuantizedTypeGetMax(MlirType type)2229bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
2239bcf13bfSAlex Zinenko   return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
2249bcf13bfSAlex Zinenko }
225