1 //===- QuantizeUtils.cpp - Support utilities for quantization -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantizeUtils.h" 10 #include "mlir/Dialect/Quant/UniformSupport.h" 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/StandardTypes.h" 13 14 using namespace mlir; 15 using namespace mlir::quant; 16 17 /// Converts a possible primitive, real expressed value attribute to a 18 /// corresponding storage attribute (typically FloatAttr -> IntegerAttr). 19 /// quantizedElementType is the QuantizedType that describes the expressed 20 /// origValue. 21 /// Returns a converter Attribute or nullptr if conversion is not possible. 22 static Attribute convertPrimitiveValueAttr( 23 Attribute origRealValue, QuantizedType quantizedElementType, 24 const UniformQuantizedValueConverter &converter, Type &outConvertedType) { 25 if (origRealValue.isa<FloatAttr>()) { 26 FloatAttr floatAttr = origRealValue.cast<FloatAttr>(); 27 outConvertedType = quantizedElementType.getStorageType(); 28 return IntegerAttr::get(quantizedElementType.getStorageType(), 29 converter.quantizeFloatToInt(floatAttr.getValue())); 30 } 31 32 return nullptr; 33 } 34 35 /// Converts a real expressed DenseFPElementsAttr to a corresponding 36 /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized 37 /// storage values assuming the given quantizedElementType and converter. 38 static DenseElementsAttr 39 convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, 40 QuantizedType quantizedElementType, 41 const UniformQuantizedValueConverter &converter) { 42 // Convert to corresponding quantized value attributes. 43 SmallVector<APInt, 8> quantValues; 44 if (realFPElementsAttr.isSplat()) { 45 quantValues.push_back( 46 converter.quantizeFloatToInt(*realFPElementsAttr.begin())); 47 } else { 48 quantValues.reserve(realFPElementsAttr.getNumElements()); 49 for (APFloat realVal : realFPElementsAttr) { 50 quantValues.push_back(converter.quantizeFloatToInt(realVal)); 51 } 52 } 53 54 // Cast from an expressed-type-based type to storage-type-based type, 55 // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). 56 ShapedType newDenseType = 57 quantizedElementType 58 .castExpressedToStorageType(realFPElementsAttr.getType()) 59 .dyn_cast_or_null<ShapedType>(); 60 if (!newDenseType) { 61 return nullptr; 62 } 63 return DenseIntElementsAttr::get(newDenseType, quantValues); 64 } 65 66 /// Converts a real expressed SplatElementsAttr to a corresponding 67 /// SplatElementsAttr containing quantized storage values assuming the given 68 /// quantizedElementType and converter. 69 static SparseElementsAttr 70 convertSparseElementsAttr(SparseElementsAttr realSparseAttr, 71 QuantizedType quantizedElementType, 72 const UniformQuantizedValueConverter &converter) { 73 DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); 74 if (!realDenseAttr.isa<DenseFPElementsAttr>()) { 75 return nullptr; 76 } 77 DenseElementsAttr quantDenseAttr = 78 convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(), 79 quantizedElementType, converter); 80 if (!quantDenseAttr) { 81 return nullptr; 82 } 83 84 // Cast from an expressed-type-based type to storage-type-based type, 85 // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). 86 ShapedType newSparseType = 87 quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) 88 .dyn_cast_or_null<ShapedType>(); 89 if (!newSparseType) { 90 return nullptr; 91 } 92 return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), 93 quantDenseAttr); 94 } 95 96 /// Converts a real expressed Attribute to a corresponding Attribute containing 97 /// quantized storage values assuming the given uniform quantizedElementType and 98 /// converter. 99 Attribute mlir::quant::quantizeAttrUniform( 100 Attribute realValue, UniformQuantizedType quantizedElementType, 101 const UniformQuantizedValueConverter &converter, Type &outConvertedType) { 102 // Fork to handle different variants of constants supported. 103 if (realValue.isa<DenseFPElementsAttr>()) { 104 // Dense tensor or vector constant. 105 auto converted = convertDenseFPElementsAttr( 106 realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); 107 outConvertedType = converted.getType(); 108 return converted; 109 } else if (realValue.isa<SparseElementsAttr>()) { 110 // Sparse tensor or vector constant. 111 auto converted = convertSparseElementsAttr( 112 realValue.cast<SparseElementsAttr>(), quantizedElementType, converter); 113 outConvertedType = converted.getType(); 114 return converted; 115 } else { 116 // Nothing else matched: try to convert a primitive. 117 return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, 118 outConvertedType); 119 } 120 } 121 122 /// Convert an attribute from a type based on 123 /// quantizedElementType.getExpressedType() to one based on 124 /// quantizedElementType.getStorageType(). 125 /// Returns nullptr if the conversion is not supported. 126 /// On success, stores the converted type in outConvertedType. 127 Attribute mlir::quant::quantizeAttr(Attribute realValue, 128 QuantizedType quantizedElementType, 129 Type &outConvertedType) { 130 if (auto uniformQuantized = 131 quantizedElementType.dyn_cast<UniformQuantizedType>()) { 132 UniformQuantizedValueConverter converter(uniformQuantized); 133 return quantizeAttrUniform(realValue, uniformQuantized, converter, 134 outConvertedType); 135 136 } else if (auto uniformQuantizedPerAxis = 137 quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) { 138 UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); 139 auto converted = converter.convert(realValue); 140 // TODO: why we need this outConvertedType? remove it? 141 if (converted) { 142 outConvertedType = converted.getType(); 143 } 144 return converted; 145 } else { 146 return nullptr; 147 } 148 } 149