1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
10 #define MLIR_DIALECT_QUANT_QUANTTYPES_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
// NOTE(review): QuantizedIntegerType is forward-declared but not referenced
// anywhere else in this header; confirm whether it is still needed.
class QuantizedIntegerType;

namespace detail {

// Uniqued storage classes backing the quantized type classes declared below
// (each is named as an ImplType / TypeBase storage parameter). They are only
// forward-declared here; their definitions live outside this header.
struct AnyQuantizedTypeStorage;
struct CalibratedQuantizedTypeStorage;
struct QuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct UniformQuantizedTypeStorage;

} // namespace detail
34 
/// Bit flags that can be OR'ed together into the `flags` word carried by
/// every quantized type.
namespace QuantizationFlags {
enum FlagValue {
  /// When set, storage values are interpreted as signed integers; when clear
  /// (the default), they are interpreted as unsigned.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
43 
44 /// Base class for all quantized types known to this dialect.
45 /// All quantized types have:
46 ///   - storageType: The (narrower) numeric type that is being used to
47 ///     approximate some expressed type.
48 ///   - expressedType: The type that is being approximated.
49 ///
50 /// The base class provides generic support for manipulating the types based
51 /// on these fields.
52 class QuantizedType : public Type {
53 public:
54   using ImplType = detail::QuantizedTypeStorage;
55   using Type::Type;
56 
57   /// The maximum number of bits supported for storage types.
58   static constexpr unsigned MaxStorageBits = 32;
59 
60   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
61                               unsigned flags, Type storageType,
62                               Type expressedType, int64_t storageTypeMin,
63                               int64_t storageTypeMax);
64 
65   /// Support method to enable LLVM-style type casting.
66   static bool classof(Type type);
67 
68   /// Gets the minimum possible stored by a storageType. storageTypeMin must
69   /// be greater than or equal to this value.
getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)70   static int64_t getDefaultMinimumForInteger(bool isSigned,
71                                              unsigned integralWidth) {
72     if (isSigned) {
73       return llvm::minIntN(integralWidth);
74     }
75     return 0;
76   }
77 
78   /// Gets the maximum possible stored by a storageType. storageTypeMax must
79   /// be less than or equal to this value.
getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)80   static int64_t getDefaultMaximumForInteger(bool isSigned,
81                                              unsigned integralWidth) {
82     if (isSigned) {
83       return llvm::maxIntN(integralWidth);
84     }
85     return llvm::maxUIntN(integralWidth);
86   }
87 
88   /// Gets the original expressed type that this quantized type approximates.
89   /// Note that this presumes that the quantized type was always derived from
90   /// a floating point type, which in the broadest definition, is not true (i.e.
91   /// it could be some form of integral, fixed type or affine type in its own
92   /// right); however, at the high level, no examples of such usage are
93   /// presently known and the restriction serves some useful purposes (such as
94   /// always being able to reverse a transformation or measure error). In most
95   /// cases, this will be f32.
96   Type getExpressedType() const;
97 
98   /// Gets the flags associated with this type. Typically a more specific
99   /// accessor is appropriate.
100   unsigned getFlags() const;
101 
102   // Convenience helpers.
103   /// Whether the storage type should be interpreted as a signed quantity
104   /// (true) or an unsigned value (false).
isSigned()105   bool isSigned() const {
106     return (getFlags() & QuantizationFlags::Signed) ==
107            QuantizationFlags::Signed;
108   }
109 
110   /// Gets the underlying type used for to store values. Note that this may
111   /// be signed or unsigned. Use the isSigned() accessor to differentiate.
112   Type getStorageType() const;
113 
114   /// The minimum value that storageType can take.
115   int64_t getStorageTypeMin() const;
116 
117   /// The maximum value that storageType can take.
118   int64_t getStorageTypeMax() const;
119 
120   /// Gets the integral bit width that the underlying storage type can exactly
121   /// represent. For integral storage types, this will just be their width.
122   unsigned getStorageTypeIntegralWidth() const;
123 
124   /// Returns whether the candidateExpressedType is a match for this
125   /// QuantizedType. This will be true if the candidate type is either a
126   /// primitive type or a container type whose element type equals this
127   /// QuantizedType's expressed type.
128   /// Examples of compatible candidateExpressedType:
129   ///   !quant.uniform<i8:f32, 1.0> =~ f32
130   ///   !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
131   bool isCompatibleExpressedType(Type candidateExpressedType);
132 
133   /// Returns the element type as a QuantizedType or nullptr if it is not
134   /// a quantized type. If the type is primitive, returns that. If it is a
135   /// container (vector/tensor), return the element type.
136   /// Examples:
137   ///   !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
138   ///   tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0>
139   static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
140 
141   /// Casts from a type based on the storageType to a corresponding type based
142   /// on this type (returns nullptr if the cast is not valid).
143   /// Examples:
144   ///   i8 -> !quant.uniform<i8:f32, 1.0>
145   ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
146   ///   vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
147   Type castFromStorageType(Type candidateType);
148 
149   /// Casts from a type based on a QuantizedType to a corresponding type based
150   /// on the storageType (returns nullptr if the cast is not valid).
151   /// This is the inverse of castFromStorageType().
152   static Type castToStorageType(Type quantizedType);
153 
154   /// Casts from a type based on the expressedType to a corresponding type based
155   /// on this type (returns nullptr if the cast is not valid).
156   /// Examples:
157   ///   f32 -> !quant.uniform<i8:f32, 1.0>
158   ///   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
159   ///   vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
160   Type castFromExpressedType(Type candidateType);
161 
162   /// Casts from a type based on QuantizedType to a corresponding type based
163   /// on the expressedType (returns nullptr if the cast is not valid).
164   /// This is the inverse of castFromExpressedType.
165   static Type castToExpressedType(Type quantizedType);
166 
167   /// Casts from a type based on the expressedType to the equivalent type
168   /// based on storageType by way of this QuantizedType. Equivalent to:
169   ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
170   /// (but with validity checks).
171   /// Example (for this = !quant.uniform<i8:f32, 1.0>):
172   ///   tensor<4xf32> -> tensor<4xi8>
173   Type castExpressedToStorageType(Type candidateType);
174 
175 private:
176   /// Hide the following methods inherited from `Type`. It is almost certainly
177   /// a bug to call them from a `QuantizedType` object. Users should call
178   /// `getStorageType` or `getExpressedType` to get the underlying types
179   /// they want to inspect.
180   using Type::isBF16;
181   using Type::isF16;
182   using Type::isF32;
183   using Type::isF64;
184   using Type::isIndex;
185   using Type::isInteger;
186 };
187 
188 /// A quantized type that maps storage to/from expressed types in an
189 /// unspecified way.
190 ///
191 /// Typical syntax:
192 ///   quant.any<i8:f32>
193 ///   quant.any<i8>
194 ///   quant.any<i8<-16,15>>
195 ///
196 /// Note that for the any type, the expressed type is optional.
197 class AnyQuantizedType
198     : public Type::TypeBase<AnyQuantizedType, QuantizedType,
199                             detail::AnyQuantizedTypeStorage> {
200 public:
201   using Base::Base;
202   using Base::getChecked;
203 
204   /// Gets an instance of the type with all parameters specified but not
205   /// checked.
206   static AnyQuantizedType get(unsigned flags, Type storageType,
207                               Type expressedType, int64_t storageTypeMin,
208                               int64_t storageTypeMax);
209 
210   /// Gets an instance of the type with all specified parameters checked.
211   /// Returns a nullptr convertible type on failure.
212   static AnyQuantizedType
213   getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
214              Type storageType, Type expressedType, int64_t storageTypeMin,
215              int64_t storageTypeMax);
216 
217   /// Verifies construction invariants and issues errors/warnings.
218   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
219                               unsigned flags, Type storageType,
220                               Type expressedType, int64_t storageTypeMin,
221                               int64_t storageTypeMax);
222 };
223 
224 /// Represents a family of uniform, quantized types.
225 ///
226 /// Each instance of this type expresses a mapping between real values (most
227 /// often expressed in floating point f32) and quantized values (either fixed
228 /// point or affine).
229 ///
230 /// The relationship is:
231 ///     real_value = scale * (quantized_value - zero_point)
232 ///
233 /// It is used as part of high level graph transformations that have the goal
234 /// of re-expressing parts of a computation in terms of this common form for
235 /// more efficient execution at runtime. In addition, it is designed to be
236 /// expressive enough to facilitate lowering to precise types and operations
237 /// in target hardware.
238 ///
239 /// As a high-level type, focused on intermediate passes, this type holds
240 /// opinions consistent with high-level usage. If lowering math kernels below
241 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
242 /// instruction sets), it is expected that the information expressed here
243 /// will be used to drive low level codegen and target specific type selection,
244 /// but this type will likely be erased in the process.
245 ///
246 /// Syntax synopsis:
247 ///   Per-layer, all parameters expressed:
248 ///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
249 ///   Per-layer, optional parameters omitted:
250 ///     !quant<uniform[StorageType]{Scale}>
251 ///
252 ///   StorageType: 'i'|'u' NumBits
253 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
254 ///   Scale: A legal double value
255 ///   ZeroPoint: An integer value
256 class UniformQuantizedType
257     : public Type::TypeBase<UniformQuantizedType, QuantizedType,
258                             detail::UniformQuantizedTypeStorage> {
259 public:
260   using Base::Base;
261   using Base::getChecked;
262 
263   /// Gets an instance of the type with all parameters specified but not
264   /// checked.
265   static UniformQuantizedType get(unsigned flags, Type storageType,
266                                   Type expressedType, double scale,
267                                   int64_t zeroPoint, int64_t storageTypeMin,
268                                   int64_t storageTypeMax);
269 
270   /// Gets an instance of the type with all specified parameters checked.
271   /// Returns a nullptr convertible type on failure.
272   static UniformQuantizedType
273   getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
274              Type storageType, Type expressedType, double scale,
275              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
276 
277   /// Verifies construction invariants and issues errors/warnings.
278   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
279                               unsigned flags, Type storageType,
280                               Type expressedType, double scale,
281                               int64_t zeroPoint, int64_t storageTypeMin,
282                               int64_t storageTypeMax);
283 
284   /// Gets the scale term. The scale designates the difference between the real
285   /// values corresponding to consecutive quantized values differing by 1.
286   double getScale() const;
287 
288   /// Gets the storage value corresponding to the real value 0 in the affine
289   /// equation.
290   int64_t getZeroPoint() const;
291 
292   // Fixed point values are real numbers divided by a scale.
293   // Currently, only signed storage types are treated as fixed point.
294   // A fixed point value can be obtained from an affine value by subtracting
295   // the zeroPoint.
296   // In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()297   bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
298 };
299 
300 /// Represents per-axis (also known as per-channel quantization).
301 ///
302 /// Syntax synopsis:
303 ///   Per-axis, all parameters expressed:
304 ///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
305 ///   Per-axis, optional parameters omitted:
306 ///     !quant<uniform[StorageType]{Scale}>
307 ///
308 ///   StorageType: 'i'|'u' NumBits
309 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
310 ///   QuantizedDim: An integer value
311 ///   QuantParams: (Scale ':' ZeroPoint)+
312 ///   Scale: A legal double value
313 ///   ZeroPoint: An integer value
314 class UniformQuantizedPerAxisType
315     : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
316                             detail::UniformQuantizedPerAxisTypeStorage> {
317 public:
318   using Base::Base;
319   using Base::getChecked;
320 
321   /// Gets an instance of the type with all parameters specified but not
322   /// checked.
323   static UniformQuantizedPerAxisType
324   get(unsigned flags, Type storageType, Type expressedType,
325       ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
326       int32_t quantizedDimension, int64_t storageTypeMin,
327       int64_t storageTypeMax);
328 
329   /// Gets an instance of the type with all specified parameters checked.
330   /// Returns a nullptr convertible type on failure.
331   static UniformQuantizedPerAxisType
332   getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
333              Type storageType, Type expressedType, ArrayRef<double> scales,
334              ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
335              int64_t storageTypeMin, int64_t storageTypeMax);
336 
337   /// Verifies construction invariants and issues errors/warnings.
338   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
339                               unsigned flags, Type storageType,
340                               Type expressedType, ArrayRef<double> scales,
341                               ArrayRef<int64_t> zeroPoints,
342                               int32_t quantizedDimension,
343                               int64_t storageTypeMin, int64_t storageTypeMax);
344 
345   /// Gets the quantization scales. The scales designate the difference between
346   /// the real values corresponding to consecutive quantized values differing
347   /// by 1. The ith scale corresponds to the ith slice in the
348   /// quantized_dimension.
349   ArrayRef<double> getScales() const;
350 
351   /// Gets the storage values corresponding to the real value 0 in the affine
352   /// equation. The ith zero point corresponds to the ith slice in the
353   /// quantized_dimension.
354   ArrayRef<int64_t> getZeroPoints() const;
355 
356   /// Specifies the dimension of the Tensor's shape that the scales and
357   /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
358   /// with quantization params:
359   ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
360   /// will be quantized across the second dimension of t.
361   ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
362   ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
363   ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
364   int32_t getQuantizedDimension() const;
365 
366   /// Fixed point values are real numbers divided by a scale.
367   /// Currently, only signed storage types are treated as fixed point.
368   /// A fixed point value can be obtained from an affine value by subtracting
369   /// the zeroPoint.
370   /// In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()371   bool isFixedPoint() const {
372     if (!isSigned())
373       return false;
374     return llvm::all_of(getZeroPoints(),
375                         [](int64_t zeroPoint) { return zeroPoint != 0; });
376   }
377 };
378 
379 /// A quantized type that infers its range from given min/max values.
380 ///
381 /// Typical syntax:
382 ///   quant.calibrated<f32<-0.922,0.981>>
383 class CalibratedQuantizedType
384     : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
385                             detail::CalibratedQuantizedTypeStorage> {
386 public:
387   using Base::Base;
388   using Base::getChecked;
389 
390   /// Gets an instance of the type with all parameters specified but not
391   /// checked.
392   static CalibratedQuantizedType get(Type expressedType, double min,
393                                      double max);
394 
395   /// Gets an instance of the type with all specified parameters checked.
396   /// Returns a nullptr convertible type on failure.
397   static CalibratedQuantizedType
398   getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType,
399              double min, double max);
400 
401   /// Verifies construction invariants and issues errors/warnings.
402   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
403                               Type expressedType, double min, double max);
404   double getMin() const;
405   double getMax() const;
406 };
407 
408 } // namespace quant
409 } // namespace mlir
410 
411 #endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
412