1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantOps.h" 10 #include "mlir/Dialect/Quant/QuantizeUtils.h" 11 #include "mlir/Dialect/Quant/UniformSupport.h" 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "gmock/gmock.h" 15 #include "gtest/gtest.h" 16 17 using namespace mlir; 18 using namespace mlir::quant; 19 20 namespace { 21 22 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. 23 class TestUniformQuantizedValueConverter 24 : public UniformQuantizedValueConverter { 25 public: 26 TestUniformQuantizedValueConverter(UniformQuantizedType type) 27 : UniformQuantizedValueConverter(type), qtype(type) {} 28 APInt quantizeFloatToInt(APFloat expressedValue) const override { 29 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L); 30 } 31 32 private: 33 UniformQuantizedType qtype; 34 }; 35 36 Attribute getTestFloatAttr(double value, MLIRContext *ctx) { 37 return FloatAttr::get(FloatType::getF32(ctx), value); 38 } 39 40 template <typename ConcreteAttrClass, typename... Arg> 41 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, 42 Arg... value) { 43 auto eleType = FloatType::getF32(ctx); 44 ShapedType tensorType; 45 if (shape.size() == 1 && shape[0] == -1) { 46 tensorType = UnrankedTensorType::get(eleType); 47 } else { 48 tensorType = RankedTensorType::get(shape, eleType); 49 } 50 return ConcreteAttrClass::get(tensorType, value...); 51 } 52 53 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, 54 ArrayRef<int64_t> shape) { 55 auto eleType = FloatType::getF32(ctx); 56 ShapedType tensorType; 57 if (shape.size() == 1 && shape[0] == -1) { 58 tensorType = UnrankedTensorType::get(eleType); 59 } else { 60 tensorType = RankedTensorType::get(shape, eleType); 61 } 62 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64)); 63 auto indices = 64 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 65 auto valuesType = RankedTensorType::get({1}, eleType); 66 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)}); 67 return SparseElementsAttr::get(tensorType, indices, values); 68 } 69 70 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { 71 return UniformQuantizedType::get(/*flags=*/false, storageType, 72 FloatType::getF32(ctx), /*scale=*/1.0, 73 /*zeroPoint=*/0, /*storageTypeMin=*/0, 74 /*storageTypeMax=*/255); 75 } 76 77 TEST(QuantizationUtilsTest, convertFloatAttrUniform) { 78 MLIRContext ctx; 79 ctx.getOrLoadDialect<QuantizationDialect>(); 80 IntegerType convertedType = IntegerType::get(&ctx, 8); 81 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 82 TestUniformQuantizedValueConverter converter(quantizedType); 83 84 auto realValue = getTestFloatAttr(1.0, &ctx); 85 Type typeResult; 86 auto valueResult = 87 quantizeAttrUniform(realValue, quantizedType, converter, typeResult); 88 89 EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5); 90 EXPECT_EQ( 91 valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(), 92 convertedType.getWidth()); 93 } 94 95 TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { 96 MLIRContext ctx; 97 ctx.getOrLoadDialect<QuantizationDialect>(); 98 IntegerType convertedType = IntegerType::get(&ctx, 8); 99 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 100 TestUniformQuantizedValueConverter converter(quantizedType); 101 auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>( 102 &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)}); 103 104 Type returnedType; 105 auto returnedValue = 106 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 107 108 // Check Elements attribute shape and kind are not changed. 109 auto tensorType = returnedType.cast<TensorType>(); 110 auto expectedTensorType = realValue.getType().cast<TensorType>(); 111 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 112 EXPECT_EQ(tensorType.getElementType(), convertedType); 113 EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>()); 114 115 // Check Elements attribute element value is expected. 116 auto firstValue = 117 returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}]; 118 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 119 } 120 121 TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { 122 MLIRContext ctx; 123 ctx.getOrLoadDialect<QuantizationDialect>(); 124 IntegerType convertedType = IntegerType::get(&ctx, 8); 125 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 126 TestUniformQuantizedValueConverter converter(quantizedType); 127 auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>( 128 &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); 129 130 Type returnedType; 131 auto returnedValue = 132 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 133 134 // Check Elements attribute shape and kind are not changed. 135 auto tensorType = returnedType.cast<TensorType>(); 136 auto expectedTensorType = realValue.getType().cast<TensorType>(); 137 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 138 EXPECT_EQ(tensorType.getElementType(), convertedType); 139 EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>()); 140 141 // Check Elements attribute element value is expected. 142 auto firstValue = 143 returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}]; 144 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 145 } 146 147 TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { 148 MLIRContext ctx; 149 ctx.getOrLoadDialect<QuantizationDialect>(); 150 IntegerType convertedType = IntegerType::get(&ctx, 8); 151 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 152 TestUniformQuantizedValueConverter converter(quantizedType); 153 auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); 154 155 Type returnedType; 156 auto returnedValue = 157 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 158 159 // Check Elements attribute shape and kind are not changed. 160 auto tensorType = returnedType.cast<TensorType>(); 161 auto expectedTensorType = realValue.getType().cast<TensorType>(); 162 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 163 EXPECT_EQ(tensorType.getElementType(), convertedType); 164 EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>()); 165 166 // Check Elements attribute element value is expected. 167 auto firstValue = 168 returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}]; 169 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 170 } 171 172 } // namespace 173