1 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
10 #define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
11 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
20 
21 namespace mlir {
22 namespace quant {
23 
24 /// Performs type conversion from an arbitrary input type to a type
25 /// that is expressed by a QuantizedType.
26 ///
27 /// This handles cases where the inputType is a supported primitive type
28 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
29 /// elemental type.
30 ///
31 /// Since conversion often involves introspecting some attributes of the
32 /// input type in order to determine how to represent it, this is a two step
33 /// process.
34 struct ExpressedToQuantizedConverter {
35   /// Creates a converter for the given input type.
36   static ExpressedToQuantizedConverter forInputType(Type inputType);
37 
38   /// Converts the inputType to be based on the given elemental type,
39   /// returning the new type (or nullptr and emit an error on failure).
40   Type convert(QuantizedType elementalType) const;
41 
42   /// Whether the conversion is legal.
43   explicit operator bool() const { return (bool)expressedType; }
44 
45   /// The input type that is being converted from.
46   /// This may be an elemental or composite type.
47   const Type inputType;
48 
49   /// Supported, elemental expressed type (i.e. f32).
50   /// Will be nullptr if conversion is not supported.
51   const Type expressedType;
52 };
53 
54 /// Reference implementation of converting between real numbers and values
55 /// represented by a UniformQuantizedType.
56 /// Note that this is not expected to be speedy and may be superseded eventually
57 /// by a more optimal implementation.
58 /// Also, the interface assumes that quantization is done per-layer and will
59 /// need to be wider for various per-channel schemes. As such, this is a
60 /// placeholder.
61 class UniformQuantizedValueConverter {
62 public:
UniformQuantizedValueConverter(UniformQuantizedType uniformType)63   explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
64       : UniformQuantizedValueConverter(
65             uniformType.getScale(),
66             static_cast<double>(uniformType.getZeroPoint()),
67             static_cast<double>(uniformType.getStorageTypeMin()),
68             static_cast<double>(uniformType.getStorageTypeMax()),
69             uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
70     assert(uniformType.getExpressedType().isa<FloatType>());
71     assert(uniformType.getStorageType().isSignlessInteger());
72   }
73 
UniformQuantizedValueConverter(double scale,double zeroPoint,double clampMin,double clampMax,uint32_t storageBitWidth,bool isSigned)74   UniformQuantizedValueConverter(double scale, double zeroPoint,
75                                  double clampMin, double clampMax,
76                                  uint32_t storageBitWidth, bool isSigned)
77       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
78         clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
79         clampMinDouble(clampMin), clampMaxDouble(clampMax),
80         storageBitWidth(storageBitWidth), isSigned(isSigned),
81         roundMode(APFloat::rmNearestTiesToAway) {}
82 
UniformQuantizedValueConverter(double scale,double zeroPoint,const APFloat & clampMin,const APFloat & clampMax,uint32_t storageBitWidth,bool isSigned)83   UniformQuantizedValueConverter(double scale, double zeroPoint,
84                                  const APFloat &clampMin,
85                                  const APFloat &clampMax,
86                                  uint32_t storageBitWidth, bool isSigned)
87       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
88         clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
89         clampMinDouble(clampMin.convertToDouble()),
90         clampMaxDouble(clampMax.convertToDouble()),
91         storageBitWidth(storageBitWidth), isSigned(isSigned),
92         roundMode(APFloat::rmNearestTiesToAway) {}
93 
quantizeFloatToInt(APFloat expressedValue)94   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
95     // This function is a performance critical code path in quantization
96     // since it runs for each single float parameter value.
97 
98     // Specialize f32->u8/i8 case to optimize performance.
99     if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
100         storageBitWidth == 8 &&
101         roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
102       return quantizeF32ToInt8(expressedValue);
103     }
104 
105     bool lossy;
106     expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
107     // fixedpoint = clamp(clampMin, clampMax, (
108     //   roundHalfToEven(expressed / scale) + zeroPoint))
109     APFloat scaled = (expressedValue / scale);
110     scaled.roundToIntegral(roundMode);
111     scaled.add(zeroPoint, roundMode);
112     APFloat fixedpoint = llvm::minimum(scaled, clampMax);
113     fixedpoint = llvm::maximum(fixedpoint, clampMin);
114 
115     llvm::APSInt result(storageBitWidth, !isSigned);
116     fixedpoint.convertToInteger(result, roundMode, &lossy);
117 
118     return std::move(result);
119   }
120 
quantizeFloatToInt64(APFloat expressedValue)121   int64_t quantizeFloatToInt64(APFloat expressedValue) const {
122     APInt qValue = quantizeFloatToInt(std::move(expressedValue));
123     return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
124   }
125 
126   virtual ~UniformQuantizedValueConverter() = default;
127 
128 private:
129   // An optimized implementation to quantize f32 to i8/u8 with C++ native
130   // arithmetic.
quantizeF32ToInt8(APFloat expressedValue)131   virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
132     assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
133     assert(storageBitWidth == 8);
134     assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
135 
136     const float realValue = expressedValue.convertToFloat();
137 
138     const double scaled = realValue / scaleDouble + zeroPointDouble;
139     // Round to nearest integer with halfway cases rounded away from zero.
140     const double scaledRounded = std::round(scaled);
141     const double clamped =
142         std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);
143 
144     uint64_t signlessResult;
145     if (isSigned) {
146       int64_t clampedInt = static_cast<int8_t>(clamped);
147       memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
148     } else {
149       signlessResult = static_cast<uint8_t>(clamped);
150     }
151     return APInt(storageBitWidth, signlessResult);
152   }
153 
154   // Keep both APFloat and double versions of the quantization parameters
155   // around since they will be used in generic and specialized arithmetic,
156   // respectively.
157   const APFloat scale;
158   const APFloat zeroPoint;
159   const APFloat clampMin;
160   const APFloat clampMax;
161 
162   const double scaleDouble;
163   const double zeroPointDouble;
164   const double clampMinDouble;
165   const double clampMaxDouble;
166 
167   const uint32_t storageBitWidth;
168   const bool isSigned;
169   const llvm::APFloat::roundingMode roundMode;
170 };
171 
172 /// An utility class to quantize an attribute by the per-axis quantization
173 /// parameters. The size of the quantization dim in the converted elements
174 /// attribute should matche the size of of scales/zeroPoints vectors in the
175 /// quantization parameters.
176 class UniformQuantizedPerAxisValueConverter {
177 public:
UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)178   explicit UniformQuantizedPerAxisValueConverter(
179       UniformQuantizedPerAxisType uniformType)
180       : scales(uniformType.getScales()),
181         zeroPoints(uniformType.getZeroPoints()),
182         clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
183         clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
184         storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
185         isSigned(uniformType.isSigned()),
186         quantizationDim(uniformType.getQuantizedDimension()) {
187     assert(uniformType.getExpressedType().isa<FloatType>());
188     assert(uniformType.getStorageType().isSignlessInteger());
189     assert(scales.size() == zeroPoints.size());
190   }
191 
192   /// Quantize an Attribute by the quantization parameters. Return nullptr if
193   /// the conversion fails or the input array isn't an ElementsAttr.
194   ElementsAttr convert(Attribute realValue);
195 
196 private:
197   /// Quantize an DenseFPElementsAttr by the quantization parameters.
198   DenseElementsAttr convert(DenseFPElementsAttr attr);
199 
200   /// Get a uniform converter for the index-th chunk along the quantizationDim.
201   /// All the elements in this chunk is quantized by the returned converter.
getPerChunkConverter(int index)202   UniformQuantizedValueConverter getPerChunkConverter(int index) const {
203     UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
204                                              clampMin, clampMax,
205                                              storageBitWidth, isSigned);
206     return converter;
207   }
208 
209   const ArrayRef<double> scales;
210   const ArrayRef<int64_t> zeroPoints;
211   const APFloat clampMin;
212   const APFloat clampMax;
213   const uint32_t storageBitWidth;
214   const bool isSigned;
215   int32_t quantizationDim;
216 };
217 
218 } // namespace quant
219 } // namespace mlir
220 
221 #endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
222