1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22
getFlags() const23 unsigned QuantizedType::getFlags() const {
24 return static_cast<ImplType *>(impl)->flags;
25 }
26
classof(Type type)27 bool QuantizedType::classof(Type type) {
28 return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30
31 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)32 QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
33 unsigned flags, Type storageType, Type expressedType,
34 int64_t storageTypeMin, int64_t storageTypeMax) {
35 // Verify that the storage type is integral.
36 // This restriction may be lifted at some point in favor of using bf16
37 // or f16 as exact representations on hardware where that is advantageous.
38 auto intStorageType = storageType.dyn_cast<IntegerType>();
39 if (!intStorageType)
40 return emitError() << "storage type must be integral";
41 unsigned integralWidth = intStorageType.getWidth();
42
43 // Verify storage width.
44 if (integralWidth == 0 || integralWidth > MaxStorageBits)
45 return emitError() << "illegal storage type size: " << integralWidth;
46
47 // Verify storageTypeMin and storageTypeMax.
48 bool isSigned =
49 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
50 int64_t defaultIntegerMin =
51 getDefaultMinimumForInteger(isSigned, integralWidth);
52 int64_t defaultIntegerMax =
53 getDefaultMaximumForInteger(isSigned, integralWidth);
54 if (storageTypeMax - storageTypeMin <= 0 ||
55 storageTypeMin < defaultIntegerMin ||
56 storageTypeMax > defaultIntegerMax) {
57 return emitError() << "illegal storage min and storage max: ("
58 << storageTypeMin << ":" << storageTypeMax << ")";
59 }
60 return success();
61 }
62
getStorageType() const63 Type QuantizedType::getStorageType() const {
64 return static_cast<ImplType *>(impl)->storageType;
65 }
66
getStorageTypeMin() const67 int64_t QuantizedType::getStorageTypeMin() const {
68 return static_cast<ImplType *>(impl)->storageTypeMin;
69 }
70
getStorageTypeMax() const71 int64_t QuantizedType::getStorageTypeMax() const {
72 return static_cast<ImplType *>(impl)->storageTypeMax;
73 }
74
getStorageTypeIntegralWidth() const75 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
76 // NOTE: If ever supporting non-integral storage types, some other scheme
77 // for determining the width will be needed.
78 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
79 }
80
getExpressedType() const81 Type QuantizedType::getExpressedType() const {
82 return static_cast<ImplType *>(impl)->expressedType;
83 }
84
isCompatibleExpressedType(Type candidateExpressedType)85 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
86 if (candidateExpressedType.isa<ShapedType>()) {
87 return candidateExpressedType.cast<ShapedType>().getElementType() ==
88 getExpressedType();
89 }
90 return candidateExpressedType == getExpressedType();
91 }
92
93 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)94 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
95 if (primitiveOrContainerType.isa<ShapedType>()) {
96 Type elementType =
97 primitiveOrContainerType.cast<ShapedType>().getElementType();
98 return elementType.dyn_cast<QuantizedType>();
99 }
100 return primitiveOrContainerType.dyn_cast<QuantizedType>();
101 }
102
castFromStorageType(Type candidateType)103 Type QuantizedType::castFromStorageType(Type candidateType) {
104 if (candidateType == getStorageType()) {
105 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
106 return *this;
107 }
108 if (candidateType.isa<RankedTensorType>()) {
109 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
110 return RankedTensorType::get(
111 candidateType.cast<RankedTensorType>().getShape(), getStorageType());
112 }
113 if (candidateType.isa<UnrankedTensorType>()) {
114 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
115 return UnrankedTensorType::get(getStorageType());
116 }
117 if (candidateType.isa<VectorType>()) {
118 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
119 return VectorType::get(candidateType.cast<VectorType>().getShape(),
120 getStorageType());
121 }
122
123 return nullptr;
124 }
125
castToStorageType(Type quantizedType)126 Type QuantizedType::castToStorageType(Type quantizedType) {
127 if (quantizedType.isa<QuantizedType>()) {
128 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
129 return quantizedType.cast<QuantizedType>().getStorageType();
130 }
131 if (quantizedType.isa<ShapedType>()) {
132 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
133 ShapedType sType = quantizedType.cast<ShapedType>();
134 if (!sType.getElementType().isa<QuantizedType>()) {
135 return nullptr;
136 }
137 Type storageType =
138 sType.getElementType().cast<QuantizedType>().getStorageType();
139 if (quantizedType.isa<RankedTensorType>()) {
140 return RankedTensorType::get(sType.getShape(), storageType);
141 }
142 if (quantizedType.isa<UnrankedTensorType>()) {
143 return UnrankedTensorType::get(storageType);
144 }
145 if (quantizedType.isa<VectorType>()) {
146 return VectorType::get(sType.getShape(), storageType);
147 }
148 }
149
150 return nullptr;
151 }
152
castFromExpressedType(Type candidateType)153 Type QuantizedType::castFromExpressedType(Type candidateType) {
154 if (candidateType == getExpressedType()) {
155 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
156 return *this;
157 }
158 if (candidateType.isa<ShapedType>()) {
159 ShapedType candidateShapedType = candidateType.cast<ShapedType>();
160 if (candidateShapedType.getElementType() != getExpressedType()) {
161 return nullptr;
162 }
163
164 if (candidateType.isa<RankedTensorType>()) {
165 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
166 return RankedTensorType::get(candidateShapedType.getShape(), *this);
167 }
168 if (candidateType.isa<UnrankedTensorType>()) {
169 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
170 return UnrankedTensorType::get(*this);
171 }
172 if (candidateType.isa<VectorType>()) {
173 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
174 return VectorType::get(candidateShapedType.getShape(), *this);
175 }
176 }
177
178 return nullptr;
179 }
180
castToExpressedType(Type quantizedType)181 Type QuantizedType::castToExpressedType(Type quantizedType) {
182 if (quantizedType.isa<QuantizedType>()) {
183 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
184 return quantizedType.cast<QuantizedType>().getExpressedType();
185 }
186 if (quantizedType.isa<ShapedType>()) {
187 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
188 ShapedType sType = quantizedType.cast<ShapedType>();
189 if (!sType.getElementType().isa<QuantizedType>()) {
190 return nullptr;
191 }
192 Type expressedType =
193 sType.getElementType().cast<QuantizedType>().getExpressedType();
194 if (quantizedType.isa<RankedTensorType>()) {
195 return RankedTensorType::get(sType.getShape(), expressedType);
196 }
197 if (quantizedType.isa<UnrankedTensorType>()) {
198 return UnrankedTensorType::get(expressedType);
199 }
200 if (quantizedType.isa<VectorType>()) {
201 return VectorType::get(sType.getShape(), expressedType);
202 }
203 }
204
205 return nullptr;
206 }
207
castExpressedToStorageType(Type candidateType)208 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
209 Type expressedQuantizedType = castFromExpressedType(candidateType);
210 if (!expressedQuantizedType) {
211 return nullptr;
212 }
213 return QuantizedType::castToStorageType(expressedQuantizedType);
214 }
215
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)216 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
217 Type expressedType,
218 int64_t storageTypeMin,
219 int64_t storageTypeMax) {
220 return Base::get(storageType.getContext(), flags, storageType, expressedType,
221 storageTypeMin, storageTypeMax);
222 }
223
224 AnyQuantizedType
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)225 AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226 unsigned flags, Type storageType,
227 Type expressedType, int64_t storageTypeMin,
228 int64_t storageTypeMax) {
229 return Base::getChecked(emitError, storageType.getContext(), flags,
230 storageType, expressedType, storageTypeMin,
231 storageTypeMax);
232 }
233
234 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)235 AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
236 unsigned flags, Type storageType, Type expressedType,
237 int64_t storageTypeMin, int64_t storageTypeMax) {
238 if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
239 storageTypeMin, storageTypeMax))) {
240 return failure();
241 }
242
243 // Verify that the expressed type is floating point.
244 // If this restriction is ever eliminated, the parser/printer must be
245 // extended.
246 if (expressedType && !expressedType.isa<FloatType>())
247 return emitError() << "expressed type must be floating point";
248
249 return success();
250 }
251
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)252 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
253 Type expressedType, double scale,
254 int64_t zeroPoint,
255 int64_t storageTypeMin,
256 int64_t storageTypeMax) {
257 return Base::get(storageType.getContext(), flags, storageType, expressedType,
258 scale, zeroPoint, storageTypeMin, storageTypeMax);
259 }
260
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)261 UniformQuantizedType UniformQuantizedType::getChecked(
262 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
263 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
264 int64_t storageTypeMin, int64_t storageTypeMax) {
265 return Base::getChecked(emitError, storageType.getContext(), flags,
266 storageType, expressedType, scale, zeroPoint,
267 storageTypeMin, storageTypeMax);
268 }
269
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)270 LogicalResult UniformQuantizedType::verify(
271 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
272 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
273 int64_t storageTypeMin, int64_t storageTypeMax) {
274 if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
275 storageTypeMin, storageTypeMax))) {
276 return failure();
277 }
278
279 // Uniform quantization requires fully expressed parameters, including
280 // expressed type.
281 if (!expressedType)
282 return emitError() << "uniform quantization requires expressed type";
283
284 // Verify that the expressed type is floating point.
285 // If this restriction is ever eliminated, the parser/printer must be
286 // extended.
287 if (!expressedType.isa<FloatType>())
288 return emitError() << "expressed type must be floating point";
289
290 // Verify scale.
291 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
292 return emitError() << "illegal scale: " << scale;
293
294 return success();
295 }
296
getScale() const297 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
298
getZeroPoint() const299 int64_t UniformQuantizedType::getZeroPoint() const {
300 return getImpl()->zeroPoint;
301 }
302
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)303 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
304 unsigned flags, Type storageType, Type expressedType,
305 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
306 int32_t quantizedDimension, int64_t storageTypeMin,
307 int64_t storageTypeMax) {
308 return Base::get(storageType.getContext(), flags, storageType, expressedType,
309 scales, zeroPoints, quantizedDimension, storageTypeMin,
310 storageTypeMax);
311 }
312
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)313 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
314 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
315 Type storageType, Type expressedType, ArrayRef<double> scales,
316 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
317 int64_t storageTypeMin, int64_t storageTypeMax) {
318 return Base::getChecked(emitError, storageType.getContext(), flags,
319 storageType, expressedType, scales, zeroPoints,
320 quantizedDimension, storageTypeMin, storageTypeMax);
321 }
322
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)323 LogicalResult UniformQuantizedPerAxisType::verify(
324 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
325 Type storageType, Type expressedType, ArrayRef<double> scales,
326 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
327 int64_t storageTypeMin, int64_t storageTypeMax) {
328 if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
329 storageTypeMin, storageTypeMax))) {
330 return failure();
331 }
332
333 // Uniform quantization requires fully expressed parameters, including
334 // expressed type.
335 if (!expressedType)
336 return emitError() << "uniform quantization requires expressed type";
337
338 // Verify that the expressed type is floating point.
339 // If this restriction is ever eliminated, the parser/printer must be
340 // extended.
341 if (!expressedType.isa<FloatType>())
342 return emitError() << "expressed type must be floating point";
343
344 // Ensure that the number of scales and zeroPoints match.
345 if (scales.size() != zeroPoints.size())
346 return emitError() << "illegal number of scales and zeroPoints: "
347 << scales.size() << ", " << zeroPoints.size();
348
349 // Verify scale.
350 for (double scale : scales) {
351 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
352 return emitError() << "illegal scale: " << scale;
353 }
354
355 return success();
356 }
357
getScales() const358 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
359 return getImpl()->getScales();
360 }
361
getZeroPoints() const362 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
363 return getImpl()->getZeroPoints();
364 }
365
getQuantizedDimension() const366 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
367 return getImpl()->quantizedDimension;
368 }
369
get(Type expressedType,double min,double max)370 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
371 double min, double max) {
372 return Base::get(expressedType.getContext(), expressedType, min, max);
373 }
374
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type expressedType,double min,double max)375 CalibratedQuantizedType CalibratedQuantizedType::getChecked(
376 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
377 double min, double max) {
378 return Base::getChecked(emitError, expressedType.getContext(), expressedType,
379 min, max);
380 }
381
382 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type expressedType,double min,double max)383 CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
384 Type expressedType, double min, double max) {
385 // Verify that the expressed type is floating point.
386 // If this restriction is ever eliminated, the parser/printer must be
387 // extended.
388 if (!expressedType.isa<FloatType>())
389 return emitError() << "expressed type must be floating point";
390 if (max <= min)
391 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
392
393 return success();
394 }
395
getMin() const396 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
397
getMax() const398 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
399