1363dd3f3SRob Suderman //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
11363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
1209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
13363dd3f3SRob Suderman 
14363dd3f3SRob Suderman using namespace mlir;
15363dd3f3SRob Suderman using namespace mlir::quant;
16363dd3f3SRob Suderman 
17363dd3f3SRob Suderman /// Converts a possible primitive, real expressed value attribute to a
18363dd3f3SRob Suderman /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
19363dd3f3SRob Suderman /// quantizedElementType is the QuantizedType that describes the expressed
20363dd3f3SRob Suderman /// origValue.
21363dd3f3SRob Suderman /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)22363dd3f3SRob Suderman static Attribute convertPrimitiveValueAttr(
23363dd3f3SRob Suderman     Attribute origRealValue, QuantizedType quantizedElementType,
24363dd3f3SRob Suderman     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
25363dd3f3SRob Suderman   if (origRealValue.isa<FloatAttr>()) {
26363dd3f3SRob Suderman     FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
27363dd3f3SRob Suderman     outConvertedType = quantizedElementType.getStorageType();
28363dd3f3SRob Suderman     return IntegerAttr::get(quantizedElementType.getStorageType(),
29363dd3f3SRob Suderman                             converter.quantizeFloatToInt(floatAttr.getValue()));
30363dd3f3SRob Suderman   }
31363dd3f3SRob Suderman 
32363dd3f3SRob Suderman   return nullptr;
33363dd3f3SRob Suderman }
34363dd3f3SRob Suderman 
35363dd3f3SRob Suderman /// Converts a real expressed DenseFPElementsAttr to a corresponding
36363dd3f3SRob Suderman /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
37363dd3f3SRob Suderman /// storage values assuming the given quantizedElementType and converter.
38363dd3f3SRob Suderman static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)39363dd3f3SRob Suderman convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
40363dd3f3SRob Suderman                            QuantizedType quantizedElementType,
41363dd3f3SRob Suderman                            const UniformQuantizedValueConverter &converter) {
42363dd3f3SRob Suderman   // Convert to corresponding quantized value attributes.
43363dd3f3SRob Suderman   SmallVector<APInt, 8> quantValues;
44363dd3f3SRob Suderman   if (realFPElementsAttr.isSplat()) {
45363dd3f3SRob Suderman     quantValues.push_back(
46363dd3f3SRob Suderman         converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
47363dd3f3SRob Suderman   } else {
48363dd3f3SRob Suderman     quantValues.reserve(realFPElementsAttr.getNumElements());
49363dd3f3SRob Suderman     for (APFloat realVal : realFPElementsAttr) {
50363dd3f3SRob Suderman       quantValues.push_back(converter.quantizeFloatToInt(realVal));
51363dd3f3SRob Suderman     }
52363dd3f3SRob Suderman   }
53363dd3f3SRob Suderman 
54363dd3f3SRob Suderman   // Cast from an expressed-type-based type to storage-type-based type,
55363dd3f3SRob Suderman   // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
56363dd3f3SRob Suderman   ShapedType newDenseType =
57363dd3f3SRob Suderman       quantizedElementType
58363dd3f3SRob Suderman           .castExpressedToStorageType(realFPElementsAttr.getType())
59363dd3f3SRob Suderman           .dyn_cast_or_null<ShapedType>();
60363dd3f3SRob Suderman   if (!newDenseType) {
61363dd3f3SRob Suderman     return nullptr;
62363dd3f3SRob Suderman   }
63363dd3f3SRob Suderman   return DenseIntElementsAttr::get(newDenseType, quantValues);
64363dd3f3SRob Suderman }
65363dd3f3SRob Suderman 
66363dd3f3SRob Suderman /// Converts a real expressed SplatElementsAttr to a corresponding
67363dd3f3SRob Suderman /// SplatElementsAttr containing quantized storage values assuming the given
68363dd3f3SRob Suderman /// quantizedElementType and converter.
69363dd3f3SRob Suderman static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)70363dd3f3SRob Suderman convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
71363dd3f3SRob Suderman                           QuantizedType quantizedElementType,
72363dd3f3SRob Suderman                           const UniformQuantizedValueConverter &converter) {
73363dd3f3SRob Suderman   DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
74363dd3f3SRob Suderman   if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
75363dd3f3SRob Suderman     return nullptr;
76363dd3f3SRob Suderman   }
77363dd3f3SRob Suderman   DenseElementsAttr quantDenseAttr =
78363dd3f3SRob Suderman       convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
79363dd3f3SRob Suderman                                  quantizedElementType, converter);
80363dd3f3SRob Suderman   if (!quantDenseAttr) {
81363dd3f3SRob Suderman     return nullptr;
82363dd3f3SRob Suderman   }
83363dd3f3SRob Suderman 
84363dd3f3SRob Suderman   // Cast from an expressed-type-based type to storage-type-based type,
85363dd3f3SRob Suderman   // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
86363dd3f3SRob Suderman   ShapedType newSparseType =
87363dd3f3SRob Suderman       quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
88363dd3f3SRob Suderman           .dyn_cast_or_null<ShapedType>();
89363dd3f3SRob Suderman   if (!newSparseType) {
90363dd3f3SRob Suderman     return nullptr;
91363dd3f3SRob Suderman   }
92363dd3f3SRob Suderman   return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
93363dd3f3SRob Suderman                                  quantDenseAttr);
94363dd3f3SRob Suderman }
95363dd3f3SRob Suderman 
96363dd3f3SRob Suderman /// Converts a real expressed Attribute to a corresponding Attribute containing
97363dd3f3SRob Suderman /// quantized storage values assuming the given uniform quantizedElementType and
98363dd3f3SRob Suderman /// converter.
quantizeAttrUniform(Attribute realValue,UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)99363dd3f3SRob Suderman Attribute mlir::quant::quantizeAttrUniform(
100363dd3f3SRob Suderman     Attribute realValue, UniformQuantizedType quantizedElementType,
101363dd3f3SRob Suderman     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
102363dd3f3SRob Suderman   // Fork to handle different variants of constants supported.
103363dd3f3SRob Suderman   if (realValue.isa<DenseFPElementsAttr>()) {
104363dd3f3SRob Suderman     // Dense tensor or vector constant.
105363dd3f3SRob Suderman     auto converted = convertDenseFPElementsAttr(
106363dd3f3SRob Suderman         realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
107363dd3f3SRob Suderman     outConvertedType = converted.getType();
108363dd3f3SRob Suderman     return converted;
109*02b6fb21SMehdi Amini   }
110*02b6fb21SMehdi Amini   if (realValue.isa<SparseElementsAttr>()) {
111363dd3f3SRob Suderman     // Sparse tensor or vector constant.
112363dd3f3SRob Suderman     auto converted = convertSparseElementsAttr(
113363dd3f3SRob Suderman         realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
114363dd3f3SRob Suderman     outConvertedType = converted.getType();
115363dd3f3SRob Suderman     return converted;
116*02b6fb21SMehdi Amini   }
117363dd3f3SRob Suderman   // Nothing else matched: try to convert a primitive.
118363dd3f3SRob Suderman   return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
119363dd3f3SRob Suderman                                    outConvertedType);
120363dd3f3SRob Suderman }
121363dd3f3SRob Suderman 
122363dd3f3SRob Suderman /// Convert an attribute from a type based on
123363dd3f3SRob Suderman /// quantizedElementType.getExpressedType() to one based on
124363dd3f3SRob Suderman /// quantizedElementType.getStorageType().
125363dd3f3SRob Suderman /// Returns nullptr if the conversion is not supported.
126363dd3f3SRob Suderman /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,QuantizedType quantizedElementType,Type & outConvertedType)127363dd3f3SRob Suderman Attribute mlir::quant::quantizeAttr(Attribute realValue,
128363dd3f3SRob Suderman                                     QuantizedType quantizedElementType,
129363dd3f3SRob Suderman                                     Type &outConvertedType) {
130363dd3f3SRob Suderman   if (auto uniformQuantized =
131363dd3f3SRob Suderman           quantizedElementType.dyn_cast<UniformQuantizedType>()) {
132363dd3f3SRob Suderman     UniformQuantizedValueConverter converter(uniformQuantized);
133363dd3f3SRob Suderman     return quantizeAttrUniform(realValue, uniformQuantized, converter,
134363dd3f3SRob Suderman                                outConvertedType);
135*02b6fb21SMehdi Amini   }
136*02b6fb21SMehdi Amini   if (auto uniformQuantizedPerAxis =
137363dd3f3SRob Suderman           quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
138363dd3f3SRob Suderman     UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
139363dd3f3SRob Suderman     auto converted = converter.convert(realValue);
1409db53a18SRiver Riddle     // TODO: why we need this outConvertedType? remove it?
141363dd3f3SRob Suderman     if (converted) {
142363dd3f3SRob Suderman       outConvertedType = converted.getType();
143363dd3f3SRob Suderman     }
144363dd3f3SRob Suderman     return converted;
145363dd3f3SRob Suderman   }
146*02b6fb21SMehdi Amini   return nullptr;
147363dd3f3SRob Suderman }
148