1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H 10 #define MLIR_DIALECT_QUANT_QUANTTYPES_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/Dialect.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/Types.h" 18 #include "llvm/Support/MathExtras.h" 19 20 namespace mlir { 21 namespace quant { 22 23 class QuantizedIntegerType; 24 25 namespace detail { 26 27 struct QuantizedTypeStorage; 28 struct AnyQuantizedTypeStorage; 29 struct UniformQuantizedTypeStorage; 30 struct UniformQuantizedPerAxisTypeStorage; 31 struct CalibratedQuantizedTypeStorage; 32 33 } // namespace detail 34 35 /// Enumeration of bit-mapped flags related to quantized types. 36 namespace QuantizationFlags { 37 enum FlagValue { 38 // Indicates that the storage type should be interpreted as a signed 39 // integer. The default is to interpret it as an unsigned value. 40 Signed = 1, 41 }; 42 } // namespace QuantizationFlags 43 44 /// Base class for all quantized types known to this dialect. 45 /// All quantized types have: 46 /// - storageType: The (narrower) numeric type that is being used to 47 /// approximate some expressed type. 48 /// - expressedType: The type that is being approximated. 49 /// 50 /// The base class provides generic support for manipulating the types based 51 /// on these fields. 52 class QuantizedType : public Type { 53 public: 54 using ImplType = detail::QuantizedTypeStorage; 55 using Type::Type; 56 57 /// The maximum number of bits supported for storage types. 58 static constexpr unsigned MaxStorageBits = 32; 59 60 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 61 unsigned flags, Type storageType, 62 Type expressedType, int64_t storageTypeMin, 63 int64_t storageTypeMax); 64 65 /// Support method to enable LLVM-style type casting. 66 static bool classof(Type type); 67 68 /// Gets the minimum possible stored by a storageType. storageTypeMin must 69 /// be greater than or equal to this value. getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)70 static int64_t getDefaultMinimumForInteger(bool isSigned, 71 unsigned integralWidth) { 72 if (isSigned) { 73 return llvm::minIntN(integralWidth); 74 } 75 return 0; 76 } 77 78 /// Gets the maximum possible stored by a storageType. storageTypeMax must 79 /// be less than or equal to this value. getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)80 static int64_t getDefaultMaximumForInteger(bool isSigned, 81 unsigned integralWidth) { 82 if (isSigned) { 83 return llvm::maxIntN(integralWidth); 84 } 85 return llvm::maxUIntN(integralWidth); 86 } 87 88 /// Gets the original expressed type that this quantized type approximates. 89 /// Note that this presumes that the quantized type was always derived from 90 /// a floating point type, which in the broadest definition, is not true (i.e. 91 /// it could be some form of integral, fixed type or affine type in its own 92 /// right); however, at the high level, no examples of such usage are 93 /// presently known and the restriction serves some useful purposes (such as 94 /// always being able to reverse a transformation or measure error). In most 95 /// cases, this will be f32. 96 Type getExpressedType() const; 97 98 /// Gets the flags associated with this type. Typically a more specific 99 /// accessor is appropriate. 100 unsigned getFlags() const; 101 102 // Convenience helpers. 103 /// Whether the storage type should be interpreted as a signed quantity 104 /// (true) or an unsigned value (false). isSigned()105 bool isSigned() const { 106 return (getFlags() & QuantizationFlags::Signed) == 107 QuantizationFlags::Signed; 108 } 109 110 /// Gets the underlying type used for to store values. Note that this may 111 /// be signed or unsigned. Use the isSigned() accessor to differentiate. 112 Type getStorageType() const; 113 114 /// The minimum value that storageType can take. 115 int64_t getStorageTypeMin() const; 116 117 /// The maximum value that storageType can take. 118 int64_t getStorageTypeMax() const; 119 120 /// Gets the integral bit width that the underlying storage type can exactly 121 /// represent. For integral storage types, this will just be their width. 122 unsigned getStorageTypeIntegralWidth() const; 123 124 /// Returns whether the candidateExpressedType is a match for this 125 /// QuantizedType. This will be true if the candidate type is either a 126 /// primitive type or a container type whose element type equals this 127 /// QuantizedType's expressed type. 128 /// Examples of compatible candidateExpressedType: 129 /// !quant.uniform<i8:f32, 1.0> =~ f32 130 /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32> 131 bool isCompatibleExpressedType(Type candidateExpressedType); 132 133 /// Returns the element type as a QuantizedType or nullptr if it is not 134 /// a quantized type. If the type is primitive, returns that. If it is a 135 /// container (vector/tensor), return the element type. 136 /// Examples: 137 /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0> 138 /// tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0> 139 static QuantizedType getQuantizedElementType(Type primitiveOrContainerType); 140 141 /// Casts from a type based on the storageType to a corresponding type based 142 /// on this type (returns nullptr if the cast is not valid). 143 /// Examples: 144 /// i8 -> !quant.uniform<i8:f32, 1.0> 145 /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>> 146 /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>> 147 Type castFromStorageType(Type candidateType); 148 149 /// Casts from a type based on a QuantizedType to a corresponding type based 150 /// on the storageType (returns nullptr if the cast is not valid). 151 /// This is the inverse of castFromStorageType(). 152 static Type castToStorageType(Type quantizedType); 153 154 /// Casts from a type based on the expressedType to a corresponding type based 155 /// on this type (returns nullptr if the cast is not valid). 156 /// Examples: 157 /// f32 -> !quant.uniform<i8:f32, 1.0> 158 /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>> 159 /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>> 160 Type castFromExpressedType(Type candidateType); 161 162 /// Casts from a type based on QuantizedType to a corresponding type based 163 /// on the expressedType (returns nullptr if the cast is not valid). 164 /// This is the inverse of castFromExpressedType. 165 static Type castToExpressedType(Type quantizedType); 166 167 /// Casts from a type based on the expressedType to the equivalent type 168 /// based on storageType by way of this QuantizedType. Equivalent to: 169 /// QuantizedType::castToStorageType(castFromExpressedType(candidateType)) 170 /// (but with validity checks). 171 /// Example (for this = !quant.uniform<i8:f32, 1.0>): 172 /// tensor<4xf32> -> tensor<4xi8> 173 Type castExpressedToStorageType(Type candidateType); 174 175 private: 176 /// Hide the following methods inherited from `Type`. It is almost certainly 177 /// a bug to call them from a `QuantizedType` object. Users should call 178 /// `getStorageType` or `getExpressedType` to get the underlying types 179 /// they want to inspect. 180 using Type::isBF16; 181 using Type::isF16; 182 using Type::isF32; 183 using Type::isF64; 184 using Type::isIndex; 185 using Type::isInteger; 186 }; 187 188 /// A quantized type that maps storage to/from expressed types in an 189 /// unspecified way. 190 /// 191 /// Typical syntax: 192 /// quant.any<i8:f32> 193 /// quant.any<i8> 194 /// quant.any<i8<-16,15>> 195 /// 196 /// Note that for the any type, the expressed type is optional. 197 class AnyQuantizedType 198 : public Type::TypeBase<AnyQuantizedType, QuantizedType, 199 detail::AnyQuantizedTypeStorage> { 200 public: 201 using Base::Base; 202 using Base::getChecked; 203 204 /// Gets an instance of the type with all parameters specified but not 205 /// checked. 206 static AnyQuantizedType get(unsigned flags, Type storageType, 207 Type expressedType, int64_t storageTypeMin, 208 int64_t storageTypeMax); 209 210 /// Gets an instance of the type with all specified parameters checked. 211 /// Returns a nullptr convertible type on failure. 212 static AnyQuantizedType 213 getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags, 214 Type storageType, Type expressedType, int64_t storageTypeMin, 215 int64_t storageTypeMax); 216 217 /// Verifies construction invariants and issues errors/warnings. 218 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 219 unsigned flags, Type storageType, 220 Type expressedType, int64_t storageTypeMin, 221 int64_t storageTypeMax); 222 }; 223 224 /// Represents a family of uniform, quantized types. 225 /// 226 /// Each instance of this type expresses a mapping between real values (most 227 /// often expressed in floating point f32) and quantized values (either fixed 228 /// point or affine). 229 /// 230 /// The relationship is: 231 /// real_value = scale * (quantized_value - zero_point) 232 /// 233 /// It is used as part of high level graph transformations that have the goal 234 /// of re-expressing parts of a computation in terms of this common form for 235 /// more efficient execution at runtime. In addition, it is designed to be 236 /// expressive enough to facilitate lowering to precise types and operations 237 /// in target hardware. 238 /// 239 /// As a high-level type, focused on intermediate passes, this type holds 240 /// opinions consistent with high-level usage. If lowering math kernels below 241 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific 242 /// instruction sets), it is expected that the information expressed here 243 /// will be used to drive low level codegen and target specific type selection, 244 /// but this type will likely be erased in the process. 245 /// 246 /// Syntax synopsis: 247 /// Per-layer, all parameters expressed: 248 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}> 249 /// Per-layer, optional parameters omitted: 250 /// !quant<uniform[StorageType]{Scale}> 251 /// 252 /// StorageType: 'i'|'u' NumBits 253 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 254 /// Scale: A legal double value 255 /// ZeroPoint: An integer value 256 class UniformQuantizedType 257 : public Type::TypeBase<UniformQuantizedType, QuantizedType, 258 detail::UniformQuantizedTypeStorage> { 259 public: 260 using Base::Base; 261 using Base::getChecked; 262 263 /// Gets an instance of the type with all parameters specified but not 264 /// checked. 265 static UniformQuantizedType get(unsigned flags, Type storageType, 266 Type expressedType, double scale, 267 int64_t zeroPoint, int64_t storageTypeMin, 268 int64_t storageTypeMax); 269 270 /// Gets an instance of the type with all specified parameters checked. 271 /// Returns a nullptr convertible type on failure. 272 static UniformQuantizedType 273 getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags, 274 Type storageType, Type expressedType, double scale, 275 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); 276 277 /// Verifies construction invariants and issues errors/warnings. 278 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 279 unsigned flags, Type storageType, 280 Type expressedType, double scale, 281 int64_t zeroPoint, int64_t storageTypeMin, 282 int64_t storageTypeMax); 283 284 /// Gets the scale term. The scale designates the difference between the real 285 /// values corresponding to consecutive quantized values differing by 1. 286 double getScale() const; 287 288 /// Gets the storage value corresponding to the real value 0 in the affine 289 /// equation. 290 int64_t getZeroPoint() const; 291 292 // Fixed point values are real numbers divided by a scale. 293 // Currently, only signed storage types are treated as fixed point. 294 // A fixed point value can be obtained from an affine value by subtracting 295 // the zeroPoint. 296 // In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()297 bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } 298 }; 299 300 /// Represents per-axis (also known as per-channel quantization). 301 /// 302 /// Syntax synopsis: 303 /// Per-axis, all parameters expressed: 304 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}> 305 /// Per-axis, optional parameters omitted: 306 /// !quant<uniform[StorageType]{Scale}> 307 /// 308 /// StorageType: 'i'|'u' NumBits 309 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 310 /// QuantizedDim: An integer value 311 /// QuantParams: (Scale ':' ZeroPoint)+ 312 /// Scale: A legal double value 313 /// ZeroPoint: An integer value 314 class UniformQuantizedPerAxisType 315 : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType, 316 detail::UniformQuantizedPerAxisTypeStorage> { 317 public: 318 using Base::Base; 319 using Base::getChecked; 320 321 /// Gets an instance of the type with all parameters specified but not 322 /// checked. 323 static UniformQuantizedPerAxisType 324 get(unsigned flags, Type storageType, Type expressedType, 325 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 326 int32_t quantizedDimension, int64_t storageTypeMin, 327 int64_t storageTypeMax); 328 329 /// Gets an instance of the type with all specified parameters checked. 330 /// Returns a nullptr convertible type on failure. 331 static UniformQuantizedPerAxisType 332 getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags, 333 Type storageType, Type expressedType, ArrayRef<double> scales, 334 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, 335 int64_t storageTypeMin, int64_t storageTypeMax); 336 337 /// Verifies construction invariants and issues errors/warnings. 338 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 339 unsigned flags, Type storageType, 340 Type expressedType, ArrayRef<double> scales, 341 ArrayRef<int64_t> zeroPoints, 342 int32_t quantizedDimension, 343 int64_t storageTypeMin, int64_t storageTypeMax); 344 345 /// Gets the quantization scales. The scales designate the difference between 346 /// the real values corresponding to consecutive quantized values differing 347 /// by 1. The ith scale corresponds to the ith slice in the 348 /// quantized_dimension. 349 ArrayRef<double> getScales() const; 350 351 /// Gets the storage values corresponding to the real value 0 in the affine 352 /// equation. The ith zero point corresponds to the ith slice in the 353 /// quantized_dimension. 354 ArrayRef<int64_t> getZeroPoints() const; 355 356 /// Specifies the dimension of the Tensor's shape that the scales and 357 /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] 358 /// with quantization params: 359 /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1 360 /// will be quantized across the second dimension of t. 361 /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 362 /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 363 /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 364 int32_t getQuantizedDimension() const; 365 366 /// Fixed point values are real numbers divided by a scale. 367 /// Currently, only signed storage types are treated as fixed point. 368 /// A fixed point value can be obtained from an affine value by subtracting 369 /// the zeroPoint. 370 /// In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()371 bool isFixedPoint() const { 372 if (!isSigned()) 373 return false; 374 return llvm::all_of(getZeroPoints(), 375 [](int64_t zeroPoint) { return zeroPoint != 0; }); 376 } 377 }; 378 379 /// A quantized type that infers its range from given min/max values. 380 /// 381 /// Typical syntax: 382 /// quant.calibrated<f32<-0.922,0.981>> 383 class CalibratedQuantizedType 384 : public Type::TypeBase<CalibratedQuantizedType, QuantizedType, 385 detail::CalibratedQuantizedTypeStorage> { 386 public: 387 using Base::Base; 388 using Base::getChecked; 389 390 /// Gets an instance of the type with all parameters specified but not 391 /// checked. 392 static CalibratedQuantizedType get(Type expressedType, double min, 393 double max); 394 395 /// Gets an instance of the type with all specified parameters checked. 396 /// Returns a nullptr convertible type on failure. 397 static CalibratedQuantizedType 398 getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType, 399 double min, double max); 400 401 /// Verifies construction invariants and issues errors/warnings. 402 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 403 Type expressedType, double min, double max); 404 double getMin() const; 405 double getMax() const; 406 }; 407 408 } // namespace quant 409 } // namespace mlir 410 411 #endif // MLIR_DIALECT_QUANT_QUANTTYPES_H 412