1363dd3f3SRob Suderman //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cmath>
#include <limits>
11363dd3f3SRob Suderman
12363dd3f3SRob Suderman using namespace mlir;
13363dd3f3SRob Suderman using namespace mlir::quant;
14363dd3f3SRob Suderman
getDefaultStorageParams(unsigned numBits,bool narrowRange,bool isSigned,MLIRContext * ctx,Type & storageType,int64_t & qmin,int64_t & qmax)15363dd3f3SRob Suderman static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16363dd3f3SRob Suderman bool isSigned, MLIRContext *ctx,
17363dd3f3SRob Suderman Type &storageType, int64_t &qmin,
18363dd3f3SRob Suderman int64_t &qmax) {
19363dd3f3SRob Suderman // Hard-coded type mapping from TFLite.
20363dd3f3SRob Suderman if (numBits <= 8) {
211b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 8);
22363dd3f3SRob Suderman if (isSigned) {
23363dd3f3SRob Suderman qmin = -128;
24363dd3f3SRob Suderman qmax = 127;
25363dd3f3SRob Suderman } else {
26363dd3f3SRob Suderman qmin = 0;
27363dd3f3SRob Suderman qmax = 255;
28363dd3f3SRob Suderman }
29363dd3f3SRob Suderman } else if (numBits <= 16) {
301b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 16);
31363dd3f3SRob Suderman if (isSigned) {
32363dd3f3SRob Suderman qmin = -32768;
33363dd3f3SRob Suderman qmax = 32767;
34363dd3f3SRob Suderman } else {
35363dd3f3SRob Suderman qmin = 0;
36363dd3f3SRob Suderman qmax = 65535;
37363dd3f3SRob Suderman }
38b578c92aSFeng Liu } else if (numBits <= 32) {
391b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 32);
40b578c92aSFeng Liu if (isSigned) {
41b578c92aSFeng Liu qmin = std::numeric_limits<int32_t>::min();
42b578c92aSFeng Liu qmax = std::numeric_limits<int32_t>::max();
43b578c92aSFeng Liu } else {
44b578c92aSFeng Liu qmin = std::numeric_limits<uint32_t>::min();
45b578c92aSFeng Liu qmax = std::numeric_limits<uint32_t>::max();
46b578c92aSFeng Liu }
47363dd3f3SRob Suderman } else {
48363dd3f3SRob Suderman return true;
49363dd3f3SRob Suderman }
50363dd3f3SRob Suderman
51363dd3f3SRob Suderman // Handle narrowRange.
52363dd3f3SRob Suderman if (narrowRange) {
53363dd3f3SRob Suderman qmin += 1;
54363dd3f3SRob Suderman }
55363dd3f3SRob Suderman return false;
56363dd3f3SRob Suderman }
57363dd3f3SRob Suderman
// Computes the quantization `scale` and an integral `nudgedZeroPoint` that map
// the real range [rmin, rmax] onto the storage range [qmin, qmax].
//
// This is a specific implementation of nudging: if 0.0 < rmin < rmax or
// rmin < rmax < 0.0, the range is shifted to include 0.0 while its width
// (rmax - rmin) stays unchanged. The zero point is derived from the shifted
// range and the scale is not changed. As a consequence, some values that lie
// in the original [rmin, rmax] range fall outside the shifted range and are
// clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// If the computed zero point is not a number (e.g. rmin or rmax is NaN or
// infinite), it is pinned to qmin rather than converting NaN to an integer,
// which would be undefined behavior. The scale is then non-finite.
// NOTE(review): the non-finite scale is presumably rejected by the caller's
// `getChecked` verification — confirm.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = qmin;
  const double qmaxDouble = qmax;
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer in [qmin, qmax]. A NaN zero
  // point fails both range comparisons, so it must be caught explicitly
  // before it reaches the float-to-integer conversion.
  if (std::isnan(zeroPointDouble) || zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    // std::round rounds halfway cases away from zero.
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106363dd3f3SRob Suderman
107363dd3f3SRob Suderman UniformQuantizedType
fakeQuantAttrsToType(Location loc,unsigned numBits,double rmin,double rmax,bool narrowRange,Type expressedType,bool isSigned)108363dd3f3SRob Suderman mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109363dd3f3SRob Suderman double rmax, bool narrowRange,
110363dd3f3SRob Suderman Type expressedType, bool isSigned) {
111363dd3f3SRob Suderman MLIRContext *ctx = expressedType.getContext();
112363dd3f3SRob Suderman unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113363dd3f3SRob Suderman Type storageType;
114363dd3f3SRob Suderman int64_t qmin;
115363dd3f3SRob Suderman int64_t qmax;
116363dd3f3SRob Suderman if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117363dd3f3SRob Suderman qmin, qmax)) {
118363dd3f3SRob Suderman return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119363dd3f3SRob Suderman nullptr);
120363dd3f3SRob Suderman }
121363dd3f3SRob Suderman
122363dd3f3SRob Suderman // Special case where min/max is close enough. The tensor contents are all
123363dd3f3SRob Suderman // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124363dd3f3SRob Suderman // points and dequantized to 0.0.
125363dd3f3SRob Suderman if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
12606e25d56SRiver Riddle return UniformQuantizedType::getChecked(
12706e25d56SRiver Riddle loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
128363dd3f3SRob Suderman }
129363dd3f3SRob Suderman
130363dd3f3SRob Suderman double scale;
131363dd3f3SRob Suderman int64_t nudgedZeroPoint;
132363dd3f3SRob Suderman getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133363dd3f3SRob Suderman
13406e25d56SRiver Riddle return UniformQuantizedType::getChecked(loc, flags, storageType,
13506e25d56SRiver Riddle expressedType, scale, nudgedZeroPoint,
13606e25d56SRiver Riddle qmin, qmax);
137363dd3f3SRob Suderman }
138363dd3f3SRob Suderman
fakeQuantAttrsToType(Location loc,unsigned numBits,int32_t quantizedDimension,ArrayRef<double> rmins,ArrayRef<double> rmaxs,bool narrowRange,Type expressedType,bool isSigned)139363dd3f3SRob Suderman UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
140363dd3f3SRob Suderman Location loc, unsigned numBits, int32_t quantizedDimension,
141363dd3f3SRob Suderman ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142363dd3f3SRob Suderman Type expressedType, bool isSigned) {
143*02b6fb21SMehdi Amini size_t axisSize = rmins.size();
144*02b6fb21SMehdi Amini if (axisSize != rmaxs.size()) {
145363dd3f3SRob Suderman return (emitError(loc, "mismatched per-axis min and max size: ")
146*02b6fb21SMehdi Amini << axisSize << " vs. " << rmaxs.size(),
147363dd3f3SRob Suderman nullptr);
148363dd3f3SRob Suderman }
149363dd3f3SRob Suderman
150363dd3f3SRob Suderman MLIRContext *ctx = expressedType.getContext();
151363dd3f3SRob Suderman Type storageType;
152363dd3f3SRob Suderman int64_t qmin;
153363dd3f3SRob Suderman int64_t qmax;
154363dd3f3SRob Suderman if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155363dd3f3SRob Suderman qmin, qmax)) {
156363dd3f3SRob Suderman return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157363dd3f3SRob Suderman nullptr);
158363dd3f3SRob Suderman }
159363dd3f3SRob Suderman
160363dd3f3SRob Suderman SmallVector<double, 4> scales;
161363dd3f3SRob Suderman SmallVector<int64_t, 4> zeroPoints;
162*02b6fb21SMehdi Amini scales.reserve(axisSize);
163*02b6fb21SMehdi Amini zeroPoints.reserve(axisSize);
164*02b6fb21SMehdi Amini for (size_t axis = 0; axis != axisSize; ++axis) {
165363dd3f3SRob Suderman double rmin = rmins[axis];
166363dd3f3SRob Suderman double rmax = rmaxs[axis];
167363dd3f3SRob Suderman if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168363dd3f3SRob Suderman scales.push_back(1.0);
169363dd3f3SRob Suderman zeroPoints.push_back(qmin);
170363dd3f3SRob Suderman continue;
171363dd3f3SRob Suderman }
172363dd3f3SRob Suderman
173363dd3f3SRob Suderman double scale;
174363dd3f3SRob Suderman int64_t nudgedZeroPoint;
175363dd3f3SRob Suderman getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176363dd3f3SRob Suderman scales.push_back(scale);
177363dd3f3SRob Suderman zeroPoints.push_back(nudgedZeroPoint);
178363dd3f3SRob Suderman }
179363dd3f3SRob Suderman
180363dd3f3SRob Suderman unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
181363dd3f3SRob Suderman return UniformQuantizedPerAxisType::getChecked(
18206e25d56SRiver Riddle loc, flags, storageType, expressedType, scales, zeroPoints,
18306e25d56SRiver Riddle quantizedDimension, qmin, qmax);
184363dd3f3SRob Suderman }
185