//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cassert>
#include <cmath>
#include <limits>

using namespace mlir;
using namespace mlir::quant;

static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
                                    bool isSigned, MLIRContext *ctx,
                                    Type &storageType, int64_t &qmin,
                                    int64_t &qmax) {
  // Hard-coded type mapping from TFLite.
  if (numBits <= 8) {
    storageType = IntegerType::get(8, ctx);
    if (isSigned) {
      qmin = -128;
      qmax = 127;
    } else {
      qmin = 0;
      qmax = 255;
    }
  } else if (numBits <= 16) {
    storageType = IntegerType::get(16, ctx);
    if (isSigned) {
      qmin = -32768;
      qmax = 32767;
    } else {
      qmin = 0;
      qmax = 65535;
    }
  } else {
    return true;
  }

  // Handle narrowRange.
  if (narrowRange) {
    qmin += 1;
  }
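  // For example, 8-bit signed storage with narrowRange yields
  // [qmin, qmax] = [-127, 127] rather than [-128, 127].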
  return false;
}

// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range is shifted to include
// 0.0, but the range width (rmax - rmin) is unchanged. The zero point is
// derived from the shifted range, and the scale is unchanged. As a
// consequence, some values that fall within the original [rmin, rmax] range
// will lie outside the shifted range and be clamped during quantization.
// TODO(fengliuai): we should nudge the scale as well, but that requires the
// fake quant op used in training to use the nudged scale too.
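// As an illustration (hypothetical values): with qmin = -128, qmax = 127,
// rmin = 1.0, and rmax = 2.0, the scale is 1.0 / 255 and both candidate zero
// points evaluate to -383, which is clamped to qmin = -128; the representable
// range becomes [0.0, 1.0], so original values above 1.0 are clamped during
// quantization.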
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = qmin;
  const double qmaxDouble = qmax;
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In floating point, solve the affine equation for a known pair
  // (real value, corresponding quantized value); two such pairs are known:
  // (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair is
  // roughly machine_epsilon * (sum of the absolute values of the terms).
  // Use the variant that yields the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer.
  nudgedZeroPoint = 0;
  if (zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    nudgedZeroPoint = std::round(zeroPointDouble);
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}

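// Converts per-tensor FakeQuant min/max attributes to a UniformQuantizedType.
// As an illustration (hypothetical values): numBits = 8, isSigned = false,
// narrowRange = false, rmin = -1.0, and rmax = 2.0 produce an 8-bit storage
// type with range [0, 255], scale = 3.0 / 255 (about 0.011765), and zero
// point 85.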
UniformQuantizedType
mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
                                  double rmax, bool narrowRange,
                                  Type expressedType, bool isSigned) {
  MLIRContext *ctx = expressedType.getContext();
  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
  Type storageType;
  int64_t qmin;
  int64_t qmax;
  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
                              qmin, qmax)) {
    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
            nullptr);
  }

  // Special case where rmin and rmax are (nearly) equal. The tensor contents
  // are all 0.0, so the scale is set to 1.0 and the tensor can be quantized to
  // the zero point and dequantized back to 0.0.
  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                            1.0, qmin, qmin, qmax, loc);
  }

  double scale;
  int64_t nudgedZeroPoint;
  getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);

  return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                          scale, nudgedZeroPoint, qmin, qmax,
                                          loc);
}

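// Per-axis variant of the conversion above. As an illustration (hypothetical
// values): numBits = 8, isSigned = true, narrowRange = false,
// quantizedDimension = 0, rmins = {-1.0, 0.0}, and rmaxs = {2.0, 2.55} produce
// per-axis scales of roughly {0.011765, 0.01} and zero points {-43, -128},
// i.e. approximately !quant.uniform<i8:f32:0, {0.011765:-43,0.01:-128}> for an
// f32 expressed type.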
UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
    Location loc, unsigned numBits, int32_t quantizedDimension,
    ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
    Type expressedType, bool isSigned) {
  size_t axisSize = rmins.size();
  if (axisSize != rmaxs.size()) {
    return (emitError(loc, "mismatched per-axis min and max size: ")
                << axisSize << " vs. " << rmaxs.size(),
            nullptr);
  }

  MLIRContext *ctx = expressedType.getContext();
  Type storageType;
  int64_t qmin;
  int64_t qmax;
  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
                              qmin, qmax)) {
    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
            nullptr);
  }

  SmallVector<double, 4> scales;
  SmallVector<int64_t, 4> zeroPoints;
  scales.reserve(axisSize);
  zeroPoints.reserve(axisSize);
  for (size_t axis = 0; axis != axisSize; ++axis) {
    double rmin = rmins[axis];
    double rmax = rmaxs[axis];
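    // Degenerate axis range (rmin and rmax nearly equal): fall back to a scale
    // of 1.0 and a zero point of qmin, mirroring the per-tensor special case
    // above.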
    if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
      scales.push_back(1.0);
      zeroPoints.push_back(qmin);
      continue;
    }

    double scale;
    int64_t nudgedZeroPoint;
    getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
    scales.push_back(scale);
    zeroPoints.push_back(nudgedZeroPoint);
  }

  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
  return UniformQuantizedPerAxisType::getChecked(
      flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
      qmin, qmax, loc);
}