1 //===- Quant.cpp - C Interface for Quant dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir/CAPI/Registration.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 #include "mlir/Dialect/Quant/QuantTypes.h"
13 
14 using namespace mlir;
15 
// Instantiates the C API dialect-registration boilerplate for
// quant::QuantizationDialect using the macro from mlir/CAPI/Registration.h.
// The exact symbols emitted are determined by that macro's definition.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
17 
18 //===---------------------------------------------------------------------===//
19 // QuantizedType
20 //===---------------------------------------------------------------------===//
21 
22 bool mlirTypeIsAQuantizedType(MlirType type) {
23   return unwrap(type).isa<quant::QuantizedType>();
24 }
25 
mlirQuantizedTypeGetSignedFlag()26 unsigned mlirQuantizedTypeGetSignedFlag() {
27   return quant::QuantizationFlags::Signed;
28 }
29 
mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)30 int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
31                                                      unsigned integralWidth) {
32   return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
33                                                            integralWidth);
34 }
35 
mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)36 int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
37                                                      unsigned integralWidth) {
38   return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
39                                                            integralWidth);
40 }
41 
mlirQuantizedTypeGetExpressedType(MlirType type)42 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
43   return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
44 }
45 
mlirQuantizedTypeGetFlags(MlirType type)46 unsigned mlirQuantizedTypeGetFlags(MlirType type) {
47   return unwrap(type).cast<quant::QuantizedType>().getFlags();
48 }
49 
mlirQuantizedTypeIsSigned(MlirType type)50 bool mlirQuantizedTypeIsSigned(MlirType type) {
51   return unwrap(type).cast<quant::QuantizedType>().isSigned();
52 }
53 
mlirQuantizedTypeGetStorageType(MlirType type)54 MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
55   return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
56 }
57 
mlirQuantizedTypeGetStorageTypeMin(MlirType type)58 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
59   return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
60 }
61 
mlirQuantizedTypeGetStorageTypeMax(MlirType type)62 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
63   return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
64 }
65 
mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type)66 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
67   return unwrap(type)
68       .cast<quant::QuantizedType>()
69       .getStorageTypeIntegralWidth();
70 }
71 
mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,MlirType candidate)72 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
73                                                 MlirType candidate) {
74   return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
75       unwrap(candidate));
76 }
77 
mlirQuantizedTypeGetQuantizedElementType(MlirType type)78 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
79   return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
80 }
81 
mlirQuantizedTypeCastFromStorageType(MlirType type,MlirType candidate)82 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
83                                               MlirType candidate) {
84   return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
85       unwrap(candidate)));
86 }
87 
mlirQuantizedTypeCastToStorageType(MlirType type)88 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
89   return wrap(quant::QuantizedType::castToStorageType(
90       unwrap(type).cast<quant::QuantizedType>()));
91 }
92 
mlirQuantizedTypeCastFromExpressedType(MlirType type,MlirType candidate)93 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
94                                                 MlirType candidate) {
95   return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
96       unwrap(candidate)));
97 }
98 
mlirQuantizedTypeCastToExpressedType(MlirType type)99 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
100   return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
101 }
102 
mlirQuantizedTypeCastExpressedToStorageType(MlirType type,MlirType candidate)103 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
104                                                      MlirType candidate) {
105   return wrap(
106       unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
107           unwrap(candidate)));
108 }
109 
110 //===---------------------------------------------------------------------===//
111 // AnyQuantizedType
112 //===---------------------------------------------------------------------===//
113 
mlirTypeIsAAnyQuantizedType(MlirType type)114 bool mlirTypeIsAAnyQuantizedType(MlirType type) {
115   return unwrap(type).isa<quant::AnyQuantizedType>();
116 }
117 
mlirAnyQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,int64_t storageTypeMin,int64_t storageTypeMax)118 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
119                                  MlirType expressedType, int64_t storageTypeMin,
120                                  int64_t storageTypeMax) {
121   return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
122                                            unwrap(expressedType),
123                                            storageTypeMin, storageTypeMax));
124 }
125 
126 //===---------------------------------------------------------------------===//
127 // UniformQuantizedType
128 //===---------------------------------------------------------------------===//
129 
mlirTypeIsAUniformQuantizedType(MlirType type)130 bool mlirTypeIsAUniformQuantizedType(MlirType type) {
131   return unwrap(type).isa<quant::UniformQuantizedType>();
132 }
133 
mlirUniformQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)134 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
135                                      MlirType expressedType, double scale,
136                                      int64_t zeroPoint, int64_t storageTypeMin,
137                                      int64_t storageTypeMax) {
138   return wrap(quant::UniformQuantizedType::get(
139       flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
140       storageTypeMin, storageTypeMax));
141 }
142 
mlirUniformQuantizedTypeGetScale(MlirType type)143 double mlirUniformQuantizedTypeGetScale(MlirType type) {
144   return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
145 }
146 
mlirUniformQuantizedTypeGetZeroPoint(MlirType type)147 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
148   return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
149 }
150 
mlirUniformQuantizedTypeIsFixedPoint(MlirType type)151 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
152   return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
153 }
154 
155 //===---------------------------------------------------------------------===//
156 // UniformQuantizedPerAxisType
157 //===---------------------------------------------------------------------===//
158 
mlirTypeIsAUniformQuantizedPerAxisType(MlirType type)159 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
160   return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
161 }
162 
mlirUniformQuantizedPerAxisTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,intptr_t nDims,double * scales,int64_t * zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)163 MlirType mlirUniformQuantizedPerAxisTypeGet(
164     unsigned flags, MlirType storageType, MlirType expressedType,
165     intptr_t nDims, double *scales, int64_t *zeroPoints,
166     int32_t quantizedDimension, int64_t storageTypeMin,
167     int64_t storageTypeMax) {
168   return wrap(quant::UniformQuantizedPerAxisType::get(
169       flags, unwrap(storageType), unwrap(expressedType),
170       llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims),
171       quantizedDimension, storageTypeMin, storageTypeMax));
172 }
173 
mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type)174 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
175   return unwrap(type)
176       .cast<quant::UniformQuantizedPerAxisType>()
177       .getScales()
178       .size();
179 }
180 
mlirUniformQuantizedPerAxisTypeGetScale(MlirType type,intptr_t pos)181 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
182   return unwrap(type)
183       .cast<quant::UniformQuantizedPerAxisType>()
184       .getScales()[pos];
185 }
186 
mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,intptr_t pos)187 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
188                                                     intptr_t pos) {
189   return unwrap(type)
190       .cast<quant::UniformQuantizedPerAxisType>()
191       .getZeroPoints()[pos];
192 }
193 
mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type)194 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
195   return unwrap(type)
196       .cast<quant::UniformQuantizedPerAxisType>()
197       .getQuantizedDimension();
198 }
199 
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type)200 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
201   return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
202 }
203 
204 //===---------------------------------------------------------------------===//
205 // CalibratedQuantizedType
206 //===---------------------------------------------------------------------===//
207 
mlirTypeIsACalibratedQuantizedType(MlirType type)208 bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
209   return unwrap(type).isa<quant::CalibratedQuantizedType>();
210 }
211 
mlirCalibratedQuantizedTypeGet(MlirType expressedType,double min,double max)212 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
213                                         double max) {
214   return wrap(
215       quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
216 }
217 
mlirCalibratedQuantizedTypeGetMin(MlirType type)218 double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
219   return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
220 }
221 
mlirCalibratedQuantizedTypeGetMax(MlirType type)222 double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
223   return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
224 }
225