//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman
9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
12363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
1309f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
14363dd3f3SRob Suderman #include "gmock/gmock.h"
15363dd3f3SRob Suderman #include "gtest/gtest.h"
16363dd3f3SRob Suderman
17363dd3f3SRob Suderman using namespace mlir;
18363dd3f3SRob Suderman using namespace mlir::quant;
19363dd3f3SRob Suderman
20363dd3f3SRob Suderman namespace {
21363dd3f3SRob Suderman
22363dd3f3SRob Suderman // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
23363dd3f3SRob Suderman class TestUniformQuantizedValueConverter
24363dd3f3SRob Suderman : public UniformQuantizedValueConverter {
25363dd3f3SRob Suderman public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)26363dd3f3SRob Suderman TestUniformQuantizedValueConverter(UniformQuantizedType type)
27363dd3f3SRob Suderman : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const28*0ae2e958SMehdi Amini APInt quantizeFloatToInt(APFloat expressedValue) const override {
29363dd3f3SRob Suderman return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
30363dd3f3SRob Suderman }
31363dd3f3SRob Suderman
32363dd3f3SRob Suderman private:
33363dd3f3SRob Suderman UniformQuantizedType qtype;
34363dd3f3SRob Suderman };
35363dd3f3SRob Suderman
getTestFloatAttr(double value,MLIRContext * ctx)36363dd3f3SRob Suderman Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
37363dd3f3SRob Suderman return FloatAttr::get(FloatType::getF32(ctx), value);
38363dd3f3SRob Suderman }
39363dd3f3SRob Suderman
40363dd3f3SRob Suderman template <typename ConcreteAttrClass, typename... Arg>
getTestElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape,Arg...value)41363dd3f3SRob Suderman ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
42363dd3f3SRob Suderman Arg... value) {
43363dd3f3SRob Suderman auto eleType = FloatType::getF32(ctx);
44363dd3f3SRob Suderman ShapedType tensorType;
45363dd3f3SRob Suderman if (shape.size() == 1 && shape[0] == -1) {
46363dd3f3SRob Suderman tensorType = UnrankedTensorType::get(eleType);
47363dd3f3SRob Suderman } else {
48363dd3f3SRob Suderman tensorType = RankedTensorType::get(shape, eleType);
49363dd3f3SRob Suderman }
50363dd3f3SRob Suderman return ConcreteAttrClass::get(tensorType, value...);
51363dd3f3SRob Suderman }
52363dd3f3SRob Suderman
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)53363dd3f3SRob Suderman ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
54363dd3f3SRob Suderman ArrayRef<int64_t> shape) {
55363dd3f3SRob Suderman auto eleType = FloatType::getF32(ctx);
56363dd3f3SRob Suderman ShapedType tensorType;
57363dd3f3SRob Suderman if (shape.size() == 1 && shape[0] == -1) {
58363dd3f3SRob Suderman tensorType = UnrankedTensorType::get(eleType);
59363dd3f3SRob Suderman } else {
60363dd3f3SRob Suderman tensorType = RankedTensorType::get(shape, eleType);
61363dd3f3SRob Suderman }
621b97cdf8SRiver Riddle auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
63363dd3f3SRob Suderman auto indices =
64363dd3f3SRob Suderman DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
65363dd3f3SRob Suderman auto valuesType = RankedTensorType::get({1}, eleType);
66363dd3f3SRob Suderman auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
67363dd3f3SRob Suderman return SparseElementsAttr::get(tensorType, indices, values);
68363dd3f3SRob Suderman }
69363dd3f3SRob Suderman
getTestQuantizedType(Type storageType,MLIRContext * ctx)70363dd3f3SRob Suderman UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
71363dd3f3SRob Suderman return UniformQuantizedType::get(/*flags=*/false, storageType,
72363dd3f3SRob Suderman FloatType::getF32(ctx), /*scale=*/1.0,
73363dd3f3SRob Suderman /*zeroPoint=*/0, /*storageTypeMin=*/0,
74363dd3f3SRob Suderman /*storageTypeMax=*/255);
75363dd3f3SRob Suderman }
76363dd3f3SRob Suderman
// Quantizing a scalar f32 FloatAttr through the test converter should yield
// an IntegerAttr holding 5 whose integer type has the storage type's width.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(&ctx, 8);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);

  auto realValue = getTestFloatAttr(1.0, &ctx);
  Type typeResult;
  auto valueResult =
      quantizeAttrUniform(realValue, quantizedType, converter, typeResult);

  // Guard before casting: cast<> on a null or wrong-kind attribute aborts the
  // whole test binary instead of reporting a gtest failure.
  ASSERT_TRUE(valueResult);
  auto intAttr = valueResult.dyn_cast<IntegerAttr>();
  ASSERT_TRUE(intAttr);

  EXPECT_EQ(intAttr.getInt(), 5);
  EXPECT_EQ(intAttr.getType().cast<IntegerType>().getWidth(),
            convertedType.getWidth());
}
94363dd3f3SRob Suderman
// Quantizing a ranked dense f32 elements attribute should keep the shape,
// switch the element type to the storage type, produce a
// DenseIntElementsAttr, and hold the converter's value (5).
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(&ctx, 8);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Guard the casts below so a null/wrong-kind result fails the test cleanly
  // instead of aborting inside cast<>.
  ASSERT_TRUE(returnedValue);
  ASSERT_TRUE(returnedType);
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  // ASSERT (not EXPECT): the cast<ElementsAttr> below depends on this.
  ASSERT_TRUE(returnedValue.isa<DenseIntElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue =
      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
120363dd3f3SRob Suderman
// Quantizing a ranked splat f32 elements attribute should keep the shape,
// switch the element type to the storage type, stay a SplatElementsAttr, and
// hold the converter's value (5).
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(&ctx, 8);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Guard the casts below so a null/wrong-kind result fails the test cleanly
  // instead of aborting inside cast<>.
  ASSERT_TRUE(returnedValue);
  ASSERT_TRUE(returnedType);
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  // ASSERT (not EXPECT): the cast<ElementsAttr> below depends on this.
  ASSERT_TRUE(returnedValue.isa<SplatElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue =
      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
146363dd3f3SRob Suderman
// Quantizing a ranked sparse f32 elements attribute should keep the shape,
// switch the element type to the storage type, stay a SparseElementsAttr, and
// hold the converter's value (5) at the stored coordinate (0, 0).
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(&ctx, 8);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Guard the casts below so a null/wrong-kind result fails the test cleanly
  // instead of aborting inside cast<>.
  ASSERT_TRUE(returnedValue);
  ASSERT_TRUE(returnedType);
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  // ASSERT (not EXPECT): the cast<ElementsAttr> below depends on this.
  ASSERT_TRUE(returnedValue.isa<SparseElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue =
      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
171363dd3f3SRob Suderman
172be0a7e9fSMehdi Amini } // namespace
173