1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantizeUtils.h"
11 #include "mlir/Dialect/Quant/UniformSupport.h"
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/StandardTypes.h"
14 #include "gmock/gmock.h"
15 #include "gtest/gtest.h"
16 
17 using namespace mlir;
18 using namespace mlir::quant;
19 
// Static registration of the quant dialect. None of the tests below register
// it explicitly, so this presumably is what makes it available to each
// locally constructed MLIRContext (verify against DialectRegistration).
static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
22 
23 namespace {
24 
25 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
26 class TestUniformQuantizedValueConverter
27     : public UniformQuantizedValueConverter {
28 public:
29   TestUniformQuantizedValueConverter(UniformQuantizedType type)
30       : UniformQuantizedValueConverter(type), qtype(type) {}
31   APInt quantizeFloatToInt(APFloat expressedValue) const {
32     return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
33   }
34 
35 private:
36   UniformQuantizedType qtype;
37 };
38 
39 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
40   return FloatAttr::get(FloatType::getF32(ctx), value);
41 }
42 
43 template <typename ConcreteAttrClass, typename... Arg>
44 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
45                                       Arg... value) {
46   auto eleType = FloatType::getF32(ctx);
47   ShapedType tensorType;
48   if (shape.size() == 1 && shape[0] == -1) {
49     tensorType = UnrankedTensorType::get(eleType);
50   } else {
51     tensorType = RankedTensorType::get(shape, eleType);
52   }
53   return ConcreteAttrClass::get(tensorType, value...);
54 }
55 
56 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
57                                        ArrayRef<int64_t> shape) {
58   auto eleType = FloatType::getF32(ctx);
59   ShapedType tensorType;
60   if (shape.size() == 1 && shape[0] == -1) {
61     tensorType = UnrankedTensorType::get(eleType);
62   } else {
63     tensorType = RankedTensorType::get(shape, eleType);
64   }
65   auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
66   auto indices =
67       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
68   auto valuesType = RankedTensorType::get({1}, eleType);
69   auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
70   return SparseElementsAttr::get(tensorType, indices, values);
71 }
72 
73 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
74   return UniformQuantizedType::get(/*flags=*/false, storageType,
75                                    FloatType::getF32(ctx), /*scale=*/1.0,
76                                    /*zeroPoint=*/0, /*storageTypeMin=*/0,
77                                    /*storageTypeMax=*/255);
78 }
79 
80 TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
81   MLIRContext ctx;
82   IntegerType convertedType = IntegerType::get(8, &ctx);
83   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
84   TestUniformQuantizedValueConverter converter(quantizedType);
85 
86   auto realValue = getTestFloatAttr(1.0, &ctx);
87   Type typeResult;
88   auto valueResult =
89       quantizeAttrUniform(realValue, quantizedType, converter, typeResult);
90 
91   EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5);
92   EXPECT_EQ(
93       valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(),
94       convertedType.getWidth());
95 }
96 
97 TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
98   MLIRContext ctx;
99   IntegerType convertedType = IntegerType::get(8, &ctx);
100   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
101   TestUniformQuantizedValueConverter converter(quantizedType);
102   auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
103       &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});
104 
105   Type returnedType;
106   auto returnedValue =
107       quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
108 
109   // Check Elements attribute shape and kind are not changed.
110   auto tensorType = returnedType.cast<TensorType>();
111   auto expectedTensorType = realValue.getType().cast<TensorType>();
112   EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
113   EXPECT_EQ(tensorType.getElementType(), convertedType);
114   EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
115 
116   // Check Elements attribute element value is expected.
117   auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
118   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
119 }
120 
121 TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
122   MLIRContext ctx;
123   IntegerType convertedType = IntegerType::get(8, &ctx);
124   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
125   TestUniformQuantizedValueConverter converter(quantizedType);
126   auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
127       &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
128 
129   Type returnedType;
130   auto returnedValue =
131       quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
132 
133   // Check Elements attribute shape and kind are not changed.
134   auto tensorType = returnedType.cast<TensorType>();
135   auto expectedTensorType = realValue.getType().cast<TensorType>();
136   EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
137   EXPECT_EQ(tensorType.getElementType(), convertedType);
138   EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());
139 
140   // Check Elements attribute element value is expected.
141   auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
142   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
143 }
144 
145 TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
146   MLIRContext ctx;
147   IntegerType convertedType = IntegerType::get(8, &ctx);
148   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
149   TestUniformQuantizedValueConverter converter(quantizedType);
150   auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
151 
152   Type returnedType;
153   auto returnedValue =
154       quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
155 
156   // Check Elements attribute shape and kind are not changed.
157   auto tensorType = returnedType.cast<TensorType>();
158   auto expectedTensorType = realValue.getType().cast<TensorType>();
159   EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
160   EXPECT_EQ(tensorType.getElementType(), convertedType);
161   EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
162 
163   // Check Elements attribute element value is expected.
164   auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
165   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
166 }
167 
168 } // end namespace
169