1363dd3f3SRob Suderman //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman
9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
11363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
1209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
13363dd3f3SRob Suderman
14363dd3f3SRob Suderman using namespace mlir;
15363dd3f3SRob Suderman using namespace mlir::quant;
16363dd3f3SRob Suderman
17363dd3f3SRob Suderman /// Converts a possible primitive, real expressed value attribute to a
18363dd3f3SRob Suderman /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
19363dd3f3SRob Suderman /// quantizedElementType is the QuantizedType that describes the expressed
20363dd3f3SRob Suderman /// origValue.
21363dd3f3SRob Suderman /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)22363dd3f3SRob Suderman static Attribute convertPrimitiveValueAttr(
23363dd3f3SRob Suderman Attribute origRealValue, QuantizedType quantizedElementType,
24363dd3f3SRob Suderman const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
25363dd3f3SRob Suderman if (origRealValue.isa<FloatAttr>()) {
26363dd3f3SRob Suderman FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
27363dd3f3SRob Suderman outConvertedType = quantizedElementType.getStorageType();
28363dd3f3SRob Suderman return IntegerAttr::get(quantizedElementType.getStorageType(),
29363dd3f3SRob Suderman converter.quantizeFloatToInt(floatAttr.getValue()));
30363dd3f3SRob Suderman }
31363dd3f3SRob Suderman
32363dd3f3SRob Suderman return nullptr;
33363dd3f3SRob Suderman }
34363dd3f3SRob Suderman
35363dd3f3SRob Suderman /// Converts a real expressed DenseFPElementsAttr to a corresponding
36363dd3f3SRob Suderman /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
37363dd3f3SRob Suderman /// storage values assuming the given quantizedElementType and converter.
38363dd3f3SRob Suderman static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)39363dd3f3SRob Suderman convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
40363dd3f3SRob Suderman QuantizedType quantizedElementType,
41363dd3f3SRob Suderman const UniformQuantizedValueConverter &converter) {
42363dd3f3SRob Suderman // Convert to corresponding quantized value attributes.
43363dd3f3SRob Suderman SmallVector<APInt, 8> quantValues;
44363dd3f3SRob Suderman if (realFPElementsAttr.isSplat()) {
45363dd3f3SRob Suderman quantValues.push_back(
46363dd3f3SRob Suderman converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
47363dd3f3SRob Suderman } else {
48363dd3f3SRob Suderman quantValues.reserve(realFPElementsAttr.getNumElements());
49363dd3f3SRob Suderman for (APFloat realVal : realFPElementsAttr) {
50363dd3f3SRob Suderman quantValues.push_back(converter.quantizeFloatToInt(realVal));
51363dd3f3SRob Suderman }
52363dd3f3SRob Suderman }
53363dd3f3SRob Suderman
54363dd3f3SRob Suderman // Cast from an expressed-type-based type to storage-type-based type,
55363dd3f3SRob Suderman // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
56363dd3f3SRob Suderman ShapedType newDenseType =
57363dd3f3SRob Suderman quantizedElementType
58363dd3f3SRob Suderman .castExpressedToStorageType(realFPElementsAttr.getType())
59363dd3f3SRob Suderman .dyn_cast_or_null<ShapedType>();
60363dd3f3SRob Suderman if (!newDenseType) {
61363dd3f3SRob Suderman return nullptr;
62363dd3f3SRob Suderman }
63363dd3f3SRob Suderman return DenseIntElementsAttr::get(newDenseType, quantValues);
64363dd3f3SRob Suderman }
65363dd3f3SRob Suderman
66363dd3f3SRob Suderman /// Converts a real expressed SplatElementsAttr to a corresponding
67363dd3f3SRob Suderman /// SplatElementsAttr containing quantized storage values assuming the given
68363dd3f3SRob Suderman /// quantizedElementType and converter.
69363dd3f3SRob Suderman static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)70363dd3f3SRob Suderman convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
71363dd3f3SRob Suderman QuantizedType quantizedElementType,
72363dd3f3SRob Suderman const UniformQuantizedValueConverter &converter) {
73363dd3f3SRob Suderman DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
74363dd3f3SRob Suderman if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
75363dd3f3SRob Suderman return nullptr;
76363dd3f3SRob Suderman }
77363dd3f3SRob Suderman DenseElementsAttr quantDenseAttr =
78363dd3f3SRob Suderman convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
79363dd3f3SRob Suderman quantizedElementType, converter);
80363dd3f3SRob Suderman if (!quantDenseAttr) {
81363dd3f3SRob Suderman return nullptr;
82363dd3f3SRob Suderman }
83363dd3f3SRob Suderman
84363dd3f3SRob Suderman // Cast from an expressed-type-based type to storage-type-based type,
85363dd3f3SRob Suderman // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
86363dd3f3SRob Suderman ShapedType newSparseType =
87363dd3f3SRob Suderman quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
88363dd3f3SRob Suderman .dyn_cast_or_null<ShapedType>();
89363dd3f3SRob Suderman if (!newSparseType) {
90363dd3f3SRob Suderman return nullptr;
91363dd3f3SRob Suderman }
92363dd3f3SRob Suderman return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
93363dd3f3SRob Suderman quantDenseAttr);
94363dd3f3SRob Suderman }
95363dd3f3SRob Suderman
96363dd3f3SRob Suderman /// Converts a real expressed Attribute to a corresponding Attribute containing
97363dd3f3SRob Suderman /// quantized storage values assuming the given uniform quantizedElementType and
98363dd3f3SRob Suderman /// converter.
quantizeAttrUniform(Attribute realValue,UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)99363dd3f3SRob Suderman Attribute mlir::quant::quantizeAttrUniform(
100363dd3f3SRob Suderman Attribute realValue, UniformQuantizedType quantizedElementType,
101363dd3f3SRob Suderman const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
102363dd3f3SRob Suderman // Fork to handle different variants of constants supported.
103363dd3f3SRob Suderman if (realValue.isa<DenseFPElementsAttr>()) {
104363dd3f3SRob Suderman // Dense tensor or vector constant.
105363dd3f3SRob Suderman auto converted = convertDenseFPElementsAttr(
106363dd3f3SRob Suderman realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
107363dd3f3SRob Suderman outConvertedType = converted.getType();
108363dd3f3SRob Suderman return converted;
109*02b6fb21SMehdi Amini }
110*02b6fb21SMehdi Amini if (realValue.isa<SparseElementsAttr>()) {
111363dd3f3SRob Suderman // Sparse tensor or vector constant.
112363dd3f3SRob Suderman auto converted = convertSparseElementsAttr(
113363dd3f3SRob Suderman realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
114363dd3f3SRob Suderman outConvertedType = converted.getType();
115363dd3f3SRob Suderman return converted;
116*02b6fb21SMehdi Amini }
117363dd3f3SRob Suderman // Nothing else matched: try to convert a primitive.
118363dd3f3SRob Suderman return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
119363dd3f3SRob Suderman outConvertedType);
120363dd3f3SRob Suderman }
121363dd3f3SRob Suderman
122363dd3f3SRob Suderman /// Convert an attribute from a type based on
123363dd3f3SRob Suderman /// quantizedElementType.getExpressedType() to one based on
124363dd3f3SRob Suderman /// quantizedElementType.getStorageType().
125363dd3f3SRob Suderman /// Returns nullptr if the conversion is not supported.
126363dd3f3SRob Suderman /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,QuantizedType quantizedElementType,Type & outConvertedType)127363dd3f3SRob Suderman Attribute mlir::quant::quantizeAttr(Attribute realValue,
128363dd3f3SRob Suderman QuantizedType quantizedElementType,
129363dd3f3SRob Suderman Type &outConvertedType) {
130363dd3f3SRob Suderman if (auto uniformQuantized =
131363dd3f3SRob Suderman quantizedElementType.dyn_cast<UniformQuantizedType>()) {
132363dd3f3SRob Suderman UniformQuantizedValueConverter converter(uniformQuantized);
133363dd3f3SRob Suderman return quantizeAttrUniform(realValue, uniformQuantized, converter,
134363dd3f3SRob Suderman outConvertedType);
135*02b6fb21SMehdi Amini }
136*02b6fb21SMehdi Amini if (auto uniformQuantizedPerAxis =
137363dd3f3SRob Suderman quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
138363dd3f3SRob Suderman UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
139363dd3f3SRob Suderman auto converted = converter.convert(realValue);
1409db53a18SRiver Riddle // TODO: why we need this outConvertedType? remove it?
141363dd3f3SRob Suderman if (converted) {
142363dd3f3SRob Suderman outConvertedType = converted.getType();
143363dd3f3SRob Suderman }
144363dd3f3SRob Suderman return converted;
145363dd3f3SRob Suderman }
146*02b6fb21SMehdi Amini return nullptr;
147363dd3f3SRob Suderman }
148