1 //===- Quant.cpp - C Interface for Quant dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir/CAPI/Registration.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 #include "mlir/Dialect/Quant/QuantTypes.h"
13
14 using namespace mlir;
15
// Instantiates the C API dialect-registration entry points for the Quant
// dialect: the C-facing name `quant` and dialect namespace `quant` are bound to
// the C++ class quant::QuantizationDialect. The exact symbols generated are
// defined by the macro in mlir/CAPI/Registration.h (NOTE(review): see that
// header for the generated function names).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
17
18 //===---------------------------------------------------------------------===//
19 // QuantizedType
20 //===---------------------------------------------------------------------===//
21
22 bool mlirTypeIsAQuantizedType(MlirType type) {
23 return unwrap(type).isa<quant::QuantizedType>();
24 }
25
mlirQuantizedTypeGetSignedFlag()26 unsigned mlirQuantizedTypeGetSignedFlag() {
27 return quant::QuantizationFlags::Signed;
28 }
29
mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)30 int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
31 unsigned integralWidth) {
32 return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
33 integralWidth);
34 }
35
mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)36 int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
37 unsigned integralWidth) {
38 return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
39 integralWidth);
40 }
41
mlirQuantizedTypeGetExpressedType(MlirType type)42 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
43 return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
44 }
45
mlirQuantizedTypeGetFlags(MlirType type)46 unsigned mlirQuantizedTypeGetFlags(MlirType type) {
47 return unwrap(type).cast<quant::QuantizedType>().getFlags();
48 }
49
mlirQuantizedTypeIsSigned(MlirType type)50 bool mlirQuantizedTypeIsSigned(MlirType type) {
51 return unwrap(type).cast<quant::QuantizedType>().isSigned();
52 }
53
mlirQuantizedTypeGetStorageType(MlirType type)54 MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
55 return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
56 }
57
mlirQuantizedTypeGetStorageTypeMin(MlirType type)58 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
59 return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
60 }
61
mlirQuantizedTypeGetStorageTypeMax(MlirType type)62 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
63 return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
64 }
65
mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type)66 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
67 return unwrap(type)
68 .cast<quant::QuantizedType>()
69 .getStorageTypeIntegralWidth();
70 }
71
mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,MlirType candidate)72 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
73 MlirType candidate) {
74 return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
75 unwrap(candidate));
76 }
77
mlirQuantizedTypeGetQuantizedElementType(MlirType type)78 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
79 return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
80 }
81
mlirQuantizedTypeCastFromStorageType(MlirType type,MlirType candidate)82 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
83 MlirType candidate) {
84 return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
85 unwrap(candidate)));
86 }
87
mlirQuantizedTypeCastToStorageType(MlirType type)88 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
89 return wrap(quant::QuantizedType::castToStorageType(
90 unwrap(type).cast<quant::QuantizedType>()));
91 }
92
mlirQuantizedTypeCastFromExpressedType(MlirType type,MlirType candidate)93 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
94 MlirType candidate) {
95 return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
96 unwrap(candidate)));
97 }
98
mlirQuantizedTypeCastToExpressedType(MlirType type)99 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
100 return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
101 }
102
mlirQuantizedTypeCastExpressedToStorageType(MlirType type,MlirType candidate)103 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
104 MlirType candidate) {
105 return wrap(
106 unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
107 unwrap(candidate)));
108 }
109
110 //===---------------------------------------------------------------------===//
111 // AnyQuantizedType
112 //===---------------------------------------------------------------------===//
113
mlirTypeIsAAnyQuantizedType(MlirType type)114 bool mlirTypeIsAAnyQuantizedType(MlirType type) {
115 return unwrap(type).isa<quant::AnyQuantizedType>();
116 }
117
mlirAnyQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,int64_t storageTypeMin,int64_t storageTypeMax)118 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
119 MlirType expressedType, int64_t storageTypeMin,
120 int64_t storageTypeMax) {
121 return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
122 unwrap(expressedType),
123 storageTypeMin, storageTypeMax));
124 }
125
126 //===---------------------------------------------------------------------===//
127 // UniformQuantizedType
128 //===---------------------------------------------------------------------===//
129
mlirTypeIsAUniformQuantizedType(MlirType type)130 bool mlirTypeIsAUniformQuantizedType(MlirType type) {
131 return unwrap(type).isa<quant::UniformQuantizedType>();
132 }
133
mlirUniformQuantizedTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)134 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
135 MlirType expressedType, double scale,
136 int64_t zeroPoint, int64_t storageTypeMin,
137 int64_t storageTypeMax) {
138 return wrap(quant::UniformQuantizedType::get(
139 flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
140 storageTypeMin, storageTypeMax));
141 }
142
mlirUniformQuantizedTypeGetScale(MlirType type)143 double mlirUniformQuantizedTypeGetScale(MlirType type) {
144 return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
145 }
146
mlirUniformQuantizedTypeGetZeroPoint(MlirType type)147 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
148 return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
149 }
150
mlirUniformQuantizedTypeIsFixedPoint(MlirType type)151 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
152 return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
153 }
154
155 //===---------------------------------------------------------------------===//
156 // UniformQuantizedPerAxisType
157 //===---------------------------------------------------------------------===//
158
mlirTypeIsAUniformQuantizedPerAxisType(MlirType type)159 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
160 return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
161 }
162
mlirUniformQuantizedPerAxisTypeGet(unsigned flags,MlirType storageType,MlirType expressedType,intptr_t nDims,double * scales,int64_t * zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)163 MlirType mlirUniformQuantizedPerAxisTypeGet(
164 unsigned flags, MlirType storageType, MlirType expressedType,
165 intptr_t nDims, double *scales, int64_t *zeroPoints,
166 int32_t quantizedDimension, int64_t storageTypeMin,
167 int64_t storageTypeMax) {
168 return wrap(quant::UniformQuantizedPerAxisType::get(
169 flags, unwrap(storageType), unwrap(expressedType),
170 llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims),
171 quantizedDimension, storageTypeMin, storageTypeMax));
172 }
173
mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type)174 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
175 return unwrap(type)
176 .cast<quant::UniformQuantizedPerAxisType>()
177 .getScales()
178 .size();
179 }
180
mlirUniformQuantizedPerAxisTypeGetScale(MlirType type,intptr_t pos)181 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
182 return unwrap(type)
183 .cast<quant::UniformQuantizedPerAxisType>()
184 .getScales()[pos];
185 }
186
mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,intptr_t pos)187 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
188 intptr_t pos) {
189 return unwrap(type)
190 .cast<quant::UniformQuantizedPerAxisType>()
191 .getZeroPoints()[pos];
192 }
193
mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type)194 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
195 return unwrap(type)
196 .cast<quant::UniformQuantizedPerAxisType>()
197 .getQuantizedDimension();
198 }
199
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type)200 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
201 return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
202 }
203
204 //===---------------------------------------------------------------------===//
205 // CalibratedQuantizedType
206 //===---------------------------------------------------------------------===//
207
mlirTypeIsACalibratedQuantizedType(MlirType type)208 bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
209 return unwrap(type).isa<quant::CalibratedQuantizedType>();
210 }
211
mlirCalibratedQuantizedTypeGet(MlirType expressedType,double min,double max)212 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
213 double max) {
214 return wrap(
215 quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
216 }
217
mlirCalibratedQuantizedTypeGetMin(MlirType type)218 double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
219 return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
220 }
221
mlirCalibratedQuantizedTypeGetMax(MlirType type)222 double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
223 return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
224 }
225