1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/FakeQuantSupport.h" 10 #include "mlir/Dialect/Quant/QuantTypes.h" 11 12 using namespace mlir; 13 using namespace mlir::quant; 14 15 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, 16 bool isSigned, MLIRContext *ctx, 17 Type &storageType, int64_t &qmin, 18 int64_t &qmax) { 19 // Hard-coded type mapping from TFLite. 20 if (numBits <= 8) { 21 storageType = IntegerType::get(8, ctx); 22 if (isSigned) { 23 qmin = -128; 24 qmax = 127; 25 } else { 26 qmin = 0; 27 qmax = 255; 28 } 29 } else if (numBits <= 16) { 30 storageType = IntegerType::get(16, ctx); 31 if (isSigned) { 32 qmin = -32768; 33 qmax = 32767; 34 } else { 35 qmin = 0; 36 qmax = 65535; 37 } 38 } else { 39 return true; 40 } 41 42 // Handle narrowRange. 43 if (narrowRange) { 44 qmin += 1; 45 } 46 return false; 47 } 48 49 // This is a specific implementation of nudging: 50 // If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted 51 // to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero 52 // point is derived from the shifted range, and the scale isn't changed. As 53 // a consequence some values, which are supposed in the original [rmin, rmax] 54 // range will be outside the shifted range and be clamped during quantization. 55 // TODO(fengliuai): we should nudge the scale as well, but that requires the 56 // fake quant op used in the training to use the nudged scale as well. 57 static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, 58 double rmax, double &scale, 59 int64_t &nudgedZeroPoint) { 60 // Determine the scale. 61 const double qminDouble = qmin; 62 const double qmaxDouble = qmax; 63 scale = (rmax - rmin) / (qmaxDouble - qminDouble); 64 65 // Zero point computation. 66 // In float, solve the affine equation for any known pair 67 // (real value, corresponding quantized value), of which, two such pairs 68 // are known: (rmin, qmin), (rmax, qmax). 69 // The arithmetic error on the zero point computed from either pair will be 70 // roughly machine_epsilon * (sum of absolute values of terms). 71 // Use the variant that adds the smaller error. 72 const double zeroPointFromMin = qminDouble - rmin / scale; 73 const double zeroPointFromMinError = 74 std::abs(qminDouble) + std::abs(rmin / scale); 75 const double zeroPointFromMax = qmaxDouble - rmax / scale; 76 const double zeroPointFromMaxError = 77 std::abs(qmaxDouble) + std::abs(rmax / scale); 78 79 const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) 80 ? zeroPointFromMin 81 : zeroPointFromMax; 82 83 // Now nudge the zero point to be an integer. 84 nudgedZeroPoint = 0; 85 if (zeroPointDouble < qminDouble) { 86 nudgedZeroPoint = qmin; 87 } else if (zeroPointDouble > qmaxDouble) { 88 nudgedZeroPoint = qmax; 89 } else { 90 nudgedZeroPoint = round(zeroPointDouble); 91 } 92 93 // By construction, the nudged zero point should always be in range. 94 assert(nudgedZeroPoint >= qmin); 95 assert(nudgedZeroPoint <= qmax); 96 } 97 98 UniformQuantizedType 99 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, 100 double rmax, bool narrowRange, 101 Type expressedType, bool isSigned) { 102 MLIRContext *ctx = expressedType.getContext(); 103 unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 104 Type storageType; 105 int64_t qmin; 106 int64_t qmax; 107 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 108 qmin, qmax)) { 109 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 110 nullptr); 111 } 112 113 // Special case where min/max is close enough. The tensor contents are all 114 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero 115 // points and dequantized to 0.0. 116 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 117 return UniformQuantizedType::getChecked(flags, storageType, expressedType, 118 1.0, qmin, qmin, qmax, loc); 119 } 120 121 double scale; 122 int64_t nudgedZeroPoint; 123 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 124 125 return UniformQuantizedType::getChecked(flags, storageType, expressedType, 126 scale, nudgedZeroPoint, qmin, qmax, 127 loc); 128 } 129 130 UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( 131 Location loc, unsigned numBits, int32_t quantizedDimension, 132 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange, 133 Type expressedType, bool isSigned) { 134 size_t axis_size = rmins.size(); 135 if (axis_size != rmaxs.size()) { 136 return (emitError(loc, "mismatched per-axis min and max size: ") 137 << axis_size << " vs. " << rmaxs.size(), 138 nullptr); 139 } 140 141 MLIRContext *ctx = expressedType.getContext(); 142 Type storageType; 143 int64_t qmin; 144 int64_t qmax; 145 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 146 qmin, qmax)) { 147 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 148 nullptr); 149 } 150 151 SmallVector<double, 4> scales; 152 SmallVector<int64_t, 4> zeroPoints; 153 scales.reserve(axis_size); 154 zeroPoints.reserve(axis_size); 155 for (size_t axis = 0; axis != axis_size; ++axis) { 156 double rmin = rmins[axis]; 157 double rmax = rmaxs[axis]; 158 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 159 scales.push_back(1.0); 160 zeroPoints.push_back(qmin); 161 continue; 162 } 163 164 double scale; 165 int64_t nudgedZeroPoint; 166 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 167 scales.push_back(scale); 168 zeroPoints.push_back(nudgedZeroPoint); 169 } 170 171 unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 172 return UniformQuantizedPerAxisType::getChecked( 173 flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, 174 qmin, qmax, loc); 175 } 176