1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantizeUtils.h"
11 #include "mlir/Dialect/Quant/UniformSupport.h"
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "gmock/gmock.h"
15 #include "gtest/gtest.h"
16
17 using namespace mlir;
18 using namespace mlir::quant;
19
20 namespace {
21
22 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
23 class TestUniformQuantizedValueConverter
24 : public UniformQuantizedValueConverter {
25 public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)26 TestUniformQuantizedValueConverter(UniformQuantizedType type)
27 : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const28 APInt quantizeFloatToInt(APFloat expressedValue) const override {
29 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
30 }
31
32 private:
33 UniformQuantizedType qtype;
34 };
35
getTestFloatAttr(double value,MLIRContext * ctx)36 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
37 return FloatAttr::get(FloatType::getF32(ctx), value);
38 }
39
40 template <typename ConcreteAttrClass, typename... Arg>
getTestElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape,Arg...value)41 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
42 Arg... value) {
43 auto eleType = FloatType::getF32(ctx);
44 ShapedType tensorType;
45 if (shape.size() == 1 && shape[0] == -1) {
46 tensorType = UnrankedTensorType::get(eleType);
47 } else {
48 tensorType = RankedTensorType::get(shape, eleType);
49 }
50 return ConcreteAttrClass::get(tensorType, value...);
51 }
52
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)53 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
54 ArrayRef<int64_t> shape) {
55 auto eleType = FloatType::getF32(ctx);
56 ShapedType tensorType;
57 if (shape.size() == 1 && shape[0] == -1) {
58 tensorType = UnrankedTensorType::get(eleType);
59 } else {
60 tensorType = RankedTensorType::get(shape, eleType);
61 }
62 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
63 auto indices =
64 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
65 auto valuesType = RankedTensorType::get({1}, eleType);
66 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
67 return SparseElementsAttr::get(tensorType, indices, values);
68 }
69
getTestQuantizedType(Type storageType,MLIRContext * ctx)70 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
71 return UniformQuantizedType::get(/*flags=*/false, storageType,
72 FloatType::getF32(ctx), /*scale=*/1.0,
73 /*zeroPoint=*/0, /*storageTypeMin=*/0,
74 /*storageTypeMax=*/255);
75 }
76
// Quantizing a scalar float attribute must yield an integer attribute that
// carries the test converter's magic value 5 and has the storage type's width.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(&ctx, 8);
  auto quantType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter converter(quantType);

  auto input = getTestFloatAttr(1.0, &ctx);
  Type convertedTypeOut;
  auto quantized =
      quantizeAttrUniform(input, quantType, converter, convertedTypeOut);

  // Value check first, then the bit width of the resulting integer type.
  auto quantizedInt = quantized.cast<IntegerAttr>();
  EXPECT_EQ(quantizedInt.getInt(), 5);
  EXPECT_EQ(quantizedInt.getType().cast<IntegerType>().getWidth(),
            storageType.getWidth());
}
94
// Quantizing a ranked dense (non-splat) f32 elements attribute must keep the
// shape, switch the element type to the storage type, produce a dense integer
// elements attribute and store the test converter's magic value 5.
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(&ctx, 8);
  auto quantType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter converter(quantType);

  // Two distinct element values so the input is not a splat.
  Attribute firstInput = getTestFloatAttr(1.0, &ctx);
  Attribute secondInput = getTestFloatAttr(2.0, &ctx);
  auto input = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {firstInput, secondInput});

  Type convertedTypeOut;
  auto quantized =
      quantizeAttrUniform(input, quantType, converter, convertedTypeOut);

  // The shape and the attribute kind must survive the conversion.
  auto actualTensorType = convertedTypeOut.cast<TensorType>();
  auto inputTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(actualTensorType.getShape(), inputTensorType.getShape());
  EXPECT_EQ(actualTensorType.getElementType(), storageType);
  EXPECT_TRUE(quantized.isa<DenseIntElementsAttr>());

  // The element at index (0, 0) must hold the converter's constant.
  auto quantizedElements = quantized.cast<ElementsAttr>();
  auto firstElement = quantizedElements.getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstElement.cast<IntegerAttr>().getInt(), 5);
}
120
// Quantizing a ranked splat f32 elements attribute must keep the shape, switch
// the element type to the storage type, remain a splat and store the test
// converter's magic value 5.
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(&ctx, 8);
  auto quantType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter converter(quantType);

  // A single attribute value makes the dense attribute a splat.
  Attribute splatInput = getTestFloatAttr(1.0, &ctx);
  auto input = getTestElementsAttr<DenseElementsAttr, Attribute>(&ctx, {1, 2},
                                                                 splatInput);

  Type convertedTypeOut;
  auto quantized =
      quantizeAttrUniform(input, quantType, converter, convertedTypeOut);

  // The shape and the attribute kind must survive the conversion.
  auto actualTensorType = convertedTypeOut.cast<TensorType>();
  auto inputTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(actualTensorType.getShape(), inputTensorType.getShape());
  EXPECT_EQ(actualTensorType.getElementType(), storageType);
  EXPECT_TRUE(quantized.isa<SplatElementsAttr>());

  // The element at index (0, 0) must hold the converter's constant.
  auto quantizedElements = quantized.cast<ElementsAttr>();
  auto firstElement = quantizedElements.getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstElement.cast<IntegerAttr>().getInt(), 5);
}
146
// Quantizing a ranked sparse f32 elements attribute must keep the shape,
// switch the element type to the storage type, remain sparse and store the
// test converter's magic value 5.
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(&ctx, 8);
  auto quantType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter converter(quantType);
  auto input = getTestSparseElementsAttr(&ctx, {1, 2});

  Type convertedTypeOut;
  auto quantized =
      quantizeAttrUniform(input, quantType, converter, convertedTypeOut);

  // The shape and the attribute kind must survive the conversion.
  auto actualTensorType = convertedTypeOut.cast<TensorType>();
  auto inputTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(actualTensorType.getShape(), inputTensorType.getShape());
  EXPECT_EQ(actualTensorType.getElementType(), storageType);
  EXPECT_TRUE(quantized.isa<SparseElementsAttr>());

  // The element at index (0, 0) must hold the converter's constant.
  auto quantizedElements = quantized.cast<ElementsAttr>();
  auto firstElement = quantizedElements.getValues<Attribute>()[{0, 0}];
  EXPECT_EQ(firstElement.cast<IntegerAttr>().getInt(), 5);
}
171
172 } // namespace
173