1 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 10 #define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 11 12 #include <utility> 13 14 #include "mlir/Dialect/Quant/QuantTypes.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Types.h" 17 #include "llvm/ADT/APFloat.h" 18 #include "llvm/ADT/APInt.h" 19 #include "llvm/ADT/APSInt.h" 20 21 namespace mlir { 22 namespace quant { 23 24 /// Performs type conversion from an arbitrary input type to a type 25 /// that is expressed by a QuantizedType. 26 /// 27 /// This handles cases where the inputType is a supported primitive type 28 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported 29 /// elemental type. 30 /// 31 /// Since conversion often involves introspecting some attributes of the 32 /// input type in order to determine how to represent it, this is a two step 33 /// process. 34 struct ExpressedToQuantizedConverter { 35 /// Creates a converter for the given input type. 36 static ExpressedToQuantizedConverter forInputType(Type inputType); 37 38 /// Converts the inputType to be based on the given elemental type, 39 /// returning the new type (or nullptr and emit an error on failure). 40 Type convert(QuantizedType elementalType) const; 41 42 /// Whether the conversion is legal. 43 explicit operator bool() const { return (bool)expressedType; } 44 45 /// The input type that is being converted from. 46 /// This may be an elemental or composite type. 47 const Type inputType; 48 49 /// Supported, elemental expressed type (i.e. f32). 50 /// Will be nullptr if conversion is not supported. 51 const Type expressedType; 52 }; 53 54 /// Reference implementation of converting between real numbers and values 55 /// represented by a UniformQuantizedType. 56 /// Note that this is not expected to be speedy and may be superseded eventually 57 /// by a more optimal implementation. 58 /// Also, the interface assumes that quantization is done per-layer and will 59 /// need to be wider for various per-channel schemes. As such, this is a 60 /// placeholder. 61 class UniformQuantizedValueConverter { 62 public: UniformQuantizedValueConverter(UniformQuantizedType uniformType)63 explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType) 64 : UniformQuantizedValueConverter( 65 uniformType.getScale(), 66 static_cast<double>(uniformType.getZeroPoint()), 67 static_cast<double>(uniformType.getStorageTypeMin()), 68 static_cast<double>(uniformType.getStorageTypeMax()), 69 uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { 70 assert(uniformType.getExpressedType().isa<FloatType>()); 71 assert(uniformType.getStorageType().isSignlessInteger()); 72 } 73 UniformQuantizedValueConverter(double scale,double zeroPoint,double clampMin,double clampMax,uint32_t storageBitWidth,bool isSigned)74 UniformQuantizedValueConverter(double scale, double zeroPoint, 75 double clampMin, double clampMax, 76 uint32_t storageBitWidth, bool isSigned) 77 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), 78 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), 79 clampMinDouble(clampMin), clampMaxDouble(clampMax), 80 storageBitWidth(storageBitWidth), isSigned(isSigned), 81 roundMode(APFloat::rmNearestTiesToAway) {} 82 UniformQuantizedValueConverter(double scale,double zeroPoint,const APFloat & clampMin,const APFloat & clampMax,uint32_t storageBitWidth,bool isSigned)83 UniformQuantizedValueConverter(double scale, double zeroPoint, 84 const APFloat &clampMin, 85 const APFloat &clampMax, 86 uint32_t storageBitWidth, bool isSigned) 87 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), 88 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), 89 clampMinDouble(clampMin.convertToDouble()), 90 clampMaxDouble(clampMax.convertToDouble()), 91 storageBitWidth(storageBitWidth), isSigned(isSigned), 92 roundMode(APFloat::rmNearestTiesToAway) {} 93 quantizeFloatToInt(APFloat expressedValue)94 virtual APInt quantizeFloatToInt(APFloat expressedValue) const { 95 // This function is a performance critical code path in quantization 96 // since it runs for each single float parameter value. 97 98 // Specialize f32->u8/i8 case to optimize performance. 99 if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() && 100 storageBitWidth == 8 && 101 roundMode == llvm::APFloatBase::rmNearestTiesToAway) { 102 return quantizeF32ToInt8(expressedValue); 103 } 104 105 bool lossy; 106 expressedValue.convert(scale.getSemantics(), roundMode, &lossy); 107 // fixedpoint = clamp(clampMin, clampMax, ( 108 // roundHalfToEven(expressed / scale) + zeroPoint)) 109 APFloat scaled = (expressedValue / scale); 110 scaled.roundToIntegral(roundMode); 111 scaled.add(zeroPoint, roundMode); 112 APFloat fixedpoint = llvm::minimum(scaled, clampMax); 113 fixedpoint = llvm::maximum(fixedpoint, clampMin); 114 115 llvm::APSInt result(storageBitWidth, !isSigned); 116 fixedpoint.convertToInteger(result, roundMode, &lossy); 117 118 return std::move(result); 119 } 120 quantizeFloatToInt64(APFloat expressedValue)121 int64_t quantizeFloatToInt64(APFloat expressedValue) const { 122 APInt qValue = quantizeFloatToInt(std::move(expressedValue)); 123 return isSigned ? qValue.getSExtValue() : qValue.getZExtValue(); 124 } 125 126 virtual ~UniformQuantizedValueConverter() = default; 127 128 private: 129 // An optimized implementation to quantize f32 to i8/u8 with C++ native 130 // arithmetic. quantizeF32ToInt8(APFloat expressedValue)131 virtual APInt quantizeF32ToInt8(APFloat expressedValue) const { 132 assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle()); 133 assert(storageBitWidth == 8); 134 assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway); 135 136 const float realValue = expressedValue.convertToFloat(); 137 138 const double scaled = realValue / scaleDouble + zeroPointDouble; 139 // Round to nearest integer with halfway cases rounded away from zero. 140 const double scaledRounded = std::round(scaled); 141 const double clamped = 142 std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble); 143 144 uint64_t signlessResult; 145 if (isSigned) { 146 int64_t clampedInt = static_cast<int8_t>(clamped); 147 memcpy(&signlessResult, &clampedInt, sizeof(clampedInt)); 148 } else { 149 signlessResult = static_cast<uint8_t>(clamped); 150 } 151 return APInt(storageBitWidth, signlessResult); 152 } 153 154 // Keep both APFloat and double versions of the quantization parameters 155 // around since they will be used in generic and specialized arithmetic, 156 // respectively. 157 const APFloat scale; 158 const APFloat zeroPoint; 159 const APFloat clampMin; 160 const APFloat clampMax; 161 162 const double scaleDouble; 163 const double zeroPointDouble; 164 const double clampMinDouble; 165 const double clampMaxDouble; 166 167 const uint32_t storageBitWidth; 168 const bool isSigned; 169 const llvm::APFloat::roundingMode roundMode; 170 }; 171 172 /// An utility class to quantize an attribute by the per-axis quantization 173 /// parameters. The size of the quantization dim in the converted elements 174 /// attribute should matche the size of of scales/zeroPoints vectors in the 175 /// quantization parameters. 176 class UniformQuantizedPerAxisValueConverter { 177 public: UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)178 explicit UniformQuantizedPerAxisValueConverter( 179 UniformQuantizedPerAxisType uniformType) 180 : scales(uniformType.getScales()), 181 zeroPoints(uniformType.getZeroPoints()), 182 clampMin(static_cast<double>(uniformType.getStorageTypeMin())), 183 clampMax(static_cast<double>(uniformType.getStorageTypeMax())), 184 storageBitWidth(uniformType.getStorageTypeIntegralWidth()), 185 isSigned(uniformType.isSigned()), 186 quantizationDim(uniformType.getQuantizedDimension()) { 187 assert(uniformType.getExpressedType().isa<FloatType>()); 188 assert(uniformType.getStorageType().isSignlessInteger()); 189 assert(scales.size() == zeroPoints.size()); 190 } 191 192 /// Quantize an Attribute by the quantization parameters. Return nullptr if 193 /// the conversion fails or the input array isn't an ElementsAttr. 194 ElementsAttr convert(Attribute realValue); 195 196 private: 197 /// Quantize an DenseFPElementsAttr by the quantization parameters. 198 DenseElementsAttr convert(DenseFPElementsAttr attr); 199 200 /// Get a uniform converter for the index-th chunk along the quantizationDim. 201 /// All the elements in this chunk is quantized by the returned converter. getPerChunkConverter(int index)202 UniformQuantizedValueConverter getPerChunkConverter(int index) const { 203 UniformQuantizedValueConverter converter(scales[index], zeroPoints[index], 204 clampMin, clampMax, 205 storageBitWidth, isSigned); 206 return converter; 207 } 208 209 const ArrayRef<double> scales; 210 const ArrayRef<int64_t> zeroPoints; 211 const APFloat clampMin; 212 const APFloat clampMax; 213 const uint32_t storageBitWidth; 214 const bool isSigned; 215 int32_t quantizationDim; 216 }; 217 218 } // namespace quant 219 } // namespace mlir 220 221 #endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 222