1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
getFlags() const23 unsigned QuantizedType::getFlags() const {
24   return static_cast<ImplType *>(impl)->flags;
25 }
26 
classof(Type type)27 bool QuantizedType::classof(Type type) {
28   return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30 
31 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)32 QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
33                       unsigned flags, Type storageType, Type expressedType,
34                       int64_t storageTypeMin, int64_t storageTypeMax) {
35   // Verify that the storage type is integral.
36   // This restriction may be lifted at some point in favor of using bf16
37   // or f16 as exact representations on hardware where that is advantageous.
38   auto intStorageType = storageType.dyn_cast<IntegerType>();
39   if (!intStorageType)
40     return emitError() << "storage type must be integral";
41   unsigned integralWidth = intStorageType.getWidth();
42 
43   // Verify storage width.
44   if (integralWidth == 0 || integralWidth > MaxStorageBits)
45     return emitError() << "illegal storage type size: " << integralWidth;
46 
47   // Verify storageTypeMin and storageTypeMax.
48   bool isSigned =
49       (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
50   int64_t defaultIntegerMin =
51       getDefaultMinimumForInteger(isSigned, integralWidth);
52   int64_t defaultIntegerMax =
53       getDefaultMaximumForInteger(isSigned, integralWidth);
54   if (storageTypeMax - storageTypeMin <= 0 ||
55       storageTypeMin < defaultIntegerMin ||
56       storageTypeMax > defaultIntegerMax) {
57     return emitError() << "illegal storage min and storage max: ("
58                        << storageTypeMin << ":" << storageTypeMax << ")";
59   }
60   return success();
61 }
62 
getStorageType() const63 Type QuantizedType::getStorageType() const {
64   return static_cast<ImplType *>(impl)->storageType;
65 }
66 
getStorageTypeMin() const67 int64_t QuantizedType::getStorageTypeMin() const {
68   return static_cast<ImplType *>(impl)->storageTypeMin;
69 }
70 
getStorageTypeMax() const71 int64_t QuantizedType::getStorageTypeMax() const {
72   return static_cast<ImplType *>(impl)->storageTypeMax;
73 }
74 
getStorageTypeIntegralWidth() const75 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
76   // NOTE: If ever supporting non-integral storage types, some other scheme
77   // for determining the width will be needed.
78   return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
79 }
80 
getExpressedType() const81 Type QuantizedType::getExpressedType() const {
82   return static_cast<ImplType *>(impl)->expressedType;
83 }
84 
isCompatibleExpressedType(Type candidateExpressedType)85 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
86   if (candidateExpressedType.isa<ShapedType>()) {
87     return candidateExpressedType.cast<ShapedType>().getElementType() ==
88            getExpressedType();
89   }
90   return candidateExpressedType == getExpressedType();
91 }
92 
93 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)94 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
95   if (primitiveOrContainerType.isa<ShapedType>()) {
96     Type elementType =
97         primitiveOrContainerType.cast<ShapedType>().getElementType();
98     return elementType.dyn_cast<QuantizedType>();
99   }
100   return primitiveOrContainerType.dyn_cast<QuantizedType>();
101 }
102 
castFromStorageType(Type candidateType)103 Type QuantizedType::castFromStorageType(Type candidateType) {
104   if (candidateType == getStorageType()) {
105     // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
106     return *this;
107   }
108   if (candidateType.isa<RankedTensorType>()) {
109     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
110     return RankedTensorType::get(
111         candidateType.cast<RankedTensorType>().getShape(), getStorageType());
112   }
113   if (candidateType.isa<UnrankedTensorType>()) {
114     // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
115     return UnrankedTensorType::get(getStorageType());
116   }
117   if (candidateType.isa<VectorType>()) {
118     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
119     return VectorType::get(candidateType.cast<VectorType>().getShape(),
120                            getStorageType());
121   }
122 
123   return nullptr;
124 }
125 
castToStorageType(Type quantizedType)126 Type QuantizedType::castToStorageType(Type quantizedType) {
127   if (quantizedType.isa<QuantizedType>()) {
128     // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
129     return quantizedType.cast<QuantizedType>().getStorageType();
130   }
131   if (quantizedType.isa<ShapedType>()) {
132     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
133     ShapedType sType = quantizedType.cast<ShapedType>();
134     if (!sType.getElementType().isa<QuantizedType>()) {
135       return nullptr;
136     }
137     Type storageType =
138         sType.getElementType().cast<QuantizedType>().getStorageType();
139     if (quantizedType.isa<RankedTensorType>()) {
140       return RankedTensorType::get(sType.getShape(), storageType);
141     }
142     if (quantizedType.isa<UnrankedTensorType>()) {
143       return UnrankedTensorType::get(storageType);
144     }
145     if (quantizedType.isa<VectorType>()) {
146       return VectorType::get(sType.getShape(), storageType);
147     }
148   }
149 
150   return nullptr;
151 }
152 
castFromExpressedType(Type candidateType)153 Type QuantizedType::castFromExpressedType(Type candidateType) {
154   if (candidateType == getExpressedType()) {
155     // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
156     return *this;
157   }
158   if (candidateType.isa<ShapedType>()) {
159     ShapedType candidateShapedType = candidateType.cast<ShapedType>();
160     if (candidateShapedType.getElementType() != getExpressedType()) {
161       return nullptr;
162     }
163 
164     if (candidateType.isa<RankedTensorType>()) {
165       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
166       return RankedTensorType::get(candidateShapedType.getShape(), *this);
167     }
168     if (candidateType.isa<UnrankedTensorType>()) {
169       // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
170       return UnrankedTensorType::get(*this);
171     }
172     if (candidateType.isa<VectorType>()) {
173       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
174       return VectorType::get(candidateShapedType.getShape(), *this);
175     }
176   }
177 
178   return nullptr;
179 }
180 
castToExpressedType(Type quantizedType)181 Type QuantizedType::castToExpressedType(Type quantizedType) {
182   if (quantizedType.isa<QuantizedType>()) {
183     // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
184     return quantizedType.cast<QuantizedType>().getExpressedType();
185   }
186   if (quantizedType.isa<ShapedType>()) {
187     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
188     ShapedType sType = quantizedType.cast<ShapedType>();
189     if (!sType.getElementType().isa<QuantizedType>()) {
190       return nullptr;
191     }
192     Type expressedType =
193         sType.getElementType().cast<QuantizedType>().getExpressedType();
194     if (quantizedType.isa<RankedTensorType>()) {
195       return RankedTensorType::get(sType.getShape(), expressedType);
196     }
197     if (quantizedType.isa<UnrankedTensorType>()) {
198       return UnrankedTensorType::get(expressedType);
199     }
200     if (quantizedType.isa<VectorType>()) {
201       return VectorType::get(sType.getShape(), expressedType);
202     }
203   }
204 
205   return nullptr;
206 }
207 
castExpressedToStorageType(Type candidateType)208 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
209   Type expressedQuantizedType = castFromExpressedType(candidateType);
210   if (!expressedQuantizedType) {
211     return nullptr;
212   }
213   return QuantizedType::castToStorageType(expressedQuantizedType);
214 }
215 
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)216 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
217                                        Type expressedType,
218                                        int64_t storageTypeMin,
219                                        int64_t storageTypeMax) {
220   return Base::get(storageType.getContext(), flags, storageType, expressedType,
221                    storageTypeMin, storageTypeMax);
222 }
223 
224 AnyQuantizedType
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)225 AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226                              unsigned flags, Type storageType,
227                              Type expressedType, int64_t storageTypeMin,
228                              int64_t storageTypeMax) {
229   return Base::getChecked(emitError, storageType.getContext(), flags,
230                           storageType, expressedType, storageTypeMin,
231                           storageTypeMax);
232 }
233 
234 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)235 AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
236                          unsigned flags, Type storageType, Type expressedType,
237                          int64_t storageTypeMin, int64_t storageTypeMax) {
238   if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
239                                    storageTypeMin, storageTypeMax))) {
240     return failure();
241   }
242 
243   // Verify that the expressed type is floating point.
244   // If this restriction is ever eliminated, the parser/printer must be
245   // extended.
246   if (expressedType && !expressedType.isa<FloatType>())
247     return emitError() << "expressed type must be floating point";
248 
249   return success();
250 }
251 
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)252 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
253                                                Type expressedType, double scale,
254                                                int64_t zeroPoint,
255                                                int64_t storageTypeMin,
256                                                int64_t storageTypeMax) {
257   return Base::get(storageType.getContext(), flags, storageType, expressedType,
258                    scale, zeroPoint, storageTypeMin, storageTypeMax);
259 }
260 
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)261 UniformQuantizedType UniformQuantizedType::getChecked(
262     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
263     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
264     int64_t storageTypeMin, int64_t storageTypeMax) {
265   return Base::getChecked(emitError, storageType.getContext(), flags,
266                           storageType, expressedType, scale, zeroPoint,
267                           storageTypeMin, storageTypeMax);
268 }
269 
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)270 LogicalResult UniformQuantizedType::verify(
271     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
272     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
273     int64_t storageTypeMin, int64_t storageTypeMax) {
274   if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
275                                    storageTypeMin, storageTypeMax))) {
276     return failure();
277   }
278 
279   // Uniform quantization requires fully expressed parameters, including
280   // expressed type.
281   if (!expressedType)
282     return emitError() << "uniform quantization requires expressed type";
283 
284   // Verify that the expressed type is floating point.
285   // If this restriction is ever eliminated, the parser/printer must be
286   // extended.
287   if (!expressedType.isa<FloatType>())
288     return emitError() << "expressed type must be floating point";
289 
290   // Verify scale.
291   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
292     return emitError() << "illegal scale: " << scale;
293 
294   return success();
295 }
296 
getScale() const297 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
298 
getZeroPoint() const299 int64_t UniformQuantizedType::getZeroPoint() const {
300   return getImpl()->zeroPoint;
301 }
302 
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)303 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
304     unsigned flags, Type storageType, Type expressedType,
305     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
306     int32_t quantizedDimension, int64_t storageTypeMin,
307     int64_t storageTypeMax) {
308   return Base::get(storageType.getContext(), flags, storageType, expressedType,
309                    scales, zeroPoints, quantizedDimension, storageTypeMin,
310                    storageTypeMax);
311 }
312 
getChecked(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)313 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
314     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
315     Type storageType, Type expressedType, ArrayRef<double> scales,
316     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
317     int64_t storageTypeMin, int64_t storageTypeMax) {
318   return Base::getChecked(emitError, storageType.getContext(), flags,
319                           storageType, expressedType, scales, zeroPoints,
320                           quantizedDimension, storageTypeMin, storageTypeMax);
321 }
322 
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)323 LogicalResult UniformQuantizedPerAxisType::verify(
324     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
325     Type storageType, Type expressedType, ArrayRef<double> scales,
326     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
327     int64_t storageTypeMin, int64_t storageTypeMax) {
328   if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
329                                    storageTypeMin, storageTypeMax))) {
330     return failure();
331   }
332 
333   // Uniform quantization requires fully expressed parameters, including
334   // expressed type.
335   if (!expressedType)
336     return emitError() << "uniform quantization requires expressed type";
337 
338   // Verify that the expressed type is floating point.
339   // If this restriction is ever eliminated, the parser/printer must be
340   // extended.
341   if (!expressedType.isa<FloatType>())
342     return emitError() << "expressed type must be floating point";
343 
344   // Ensure that the number of scales and zeroPoints match.
345   if (scales.size() != zeroPoints.size())
346     return emitError() << "illegal number of scales and zeroPoints: "
347                        << scales.size() << ", " << zeroPoints.size();
348 
349   // Verify scale.
350   for (double scale : scales) {
351     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
352       return emitError() << "illegal scale: " << scale;
353   }
354 
355   return success();
356 }
357 
getScales() const358 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
359   return getImpl()->getScales();
360 }
361 
getZeroPoints() const362 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
363   return getImpl()->getZeroPoints();
364 }
365 
getQuantizedDimension() const366 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
367   return getImpl()->quantizedDimension;
368 }
369 
get(Type expressedType,double min,double max)370 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
371                                                      double min, double max) {
372   return Base::get(expressedType.getContext(), expressedType, min, max);
373 }
374 
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type expressedType,double min,double max)375 CalibratedQuantizedType CalibratedQuantizedType::getChecked(
376     function_ref<InFlightDiagnostic()> emitError, Type expressedType,
377     double min, double max) {
378   return Base::getChecked(emitError, expressedType.getContext(), expressedType,
379                           min, max);
380 }
381 
382 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type expressedType,double min,double max)383 CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
384                                 Type expressedType, double min, double max) {
385   // Verify that the expressed type is floating point.
386   // If this restriction is ever eliminated, the parser/printer must be
387   // extended.
388   if (!expressedType.isa<FloatType>())
389     return emitError() << "expressed type must be floating point";
390   if (max <= min)
391     return emitError() << "illegal min and max: (" << min << ":" << max << ")";
392 
393   return success();
394 }
395 
getMin() const396 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
397 
getMax() const398 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
399