1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantOps.h" 10 #include "mlir/Dialect/Quant/QuantizeUtils.h" 11 #include "mlir/Dialect/Quant/UniformSupport.h" 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/StandardTypes.h" 14 #include "gmock/gmock.h" 15 #include "gtest/gtest.h" 16 17 using namespace mlir; 18 using namespace mlir::quant; 19 20 // Load the quant dialect 21 static DialectRegistration<QuantizationDialect> QuantOpsRegistration; 22 23 namespace { 24 25 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. 26 class TestUniformQuantizedValueConverter 27 : public UniformQuantizedValueConverter { 28 public: 29 TestUniformQuantizedValueConverter(UniformQuantizedType type) 30 : UniformQuantizedValueConverter(type), qtype(type) {} 31 APInt quantizeFloatToInt(APFloat expressedValue) const { 32 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L); 33 } 34 35 private: 36 UniformQuantizedType qtype; 37 }; 38 39 Attribute getTestFloatAttr(double value, MLIRContext *ctx) { 40 return FloatAttr::get(FloatType::getF32(ctx), value); 41 } 42 43 template <typename ConcreteAttrClass, typename... Arg> 44 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, 45 Arg... value) { 46 auto eleType = FloatType::getF32(ctx); 47 ShapedType tensorType; 48 if (shape.size() == 1 && shape[0] == -1) { 49 tensorType = UnrankedTensorType::get(eleType); 50 } else { 51 tensorType = RankedTensorType::get(shape, eleType); 52 } 53 return ConcreteAttrClass::get(tensorType, value...); 54 } 55 56 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, 57 ArrayRef<int64_t> shape) { 58 auto eleType = FloatType::getF32(ctx); 59 ShapedType tensorType; 60 if (shape.size() == 1 && shape[0] == -1) { 61 tensorType = UnrankedTensorType::get(eleType); 62 } else { 63 tensorType = RankedTensorType::get(shape, eleType); 64 } 65 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx)); 66 auto indices = 67 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 68 auto valuesType = RankedTensorType::get({1}, eleType); 69 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)}); 70 return SparseElementsAttr::get(tensorType, indices, values); 71 } 72 73 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { 74 return UniformQuantizedType::get(/*flags=*/false, storageType, 75 FloatType::getF32(ctx), /*scale=*/1.0, 76 /*zeroPoint=*/0, /*storageTypeMin=*/0, 77 /*storageTypeMax=*/255); 78 } 79 80 TEST(QuantizationUtilsTest, convertFloatAttrUniform) { 81 MLIRContext ctx; 82 IntegerType convertedType = IntegerType::get(8, &ctx); 83 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 84 TestUniformQuantizedValueConverter converter(quantizedType); 85 86 auto realValue = getTestFloatAttr(1.0, &ctx); 87 Type typeResult; 88 auto valueResult = 89 quantizeAttrUniform(realValue, quantizedType, converter, typeResult); 90 91 EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5); 92 EXPECT_EQ( 93 valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(), 94 convertedType.getWidth()); 95 } 96 97 TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { 98 MLIRContext ctx; 99 IntegerType convertedType = IntegerType::get(8, &ctx); 100 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 101 TestUniformQuantizedValueConverter converter(quantizedType); 102 auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>( 103 &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)}); 104 105 Type returnedType; 106 auto returnedValue = 107 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 108 109 // Check Elements attribute shape and kind are not changed. 110 auto tensorType = returnedType.cast<TensorType>(); 111 auto expectedTensorType = realValue.getType().cast<TensorType>(); 112 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 113 EXPECT_EQ(tensorType.getElementType(), convertedType); 114 EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>()); 115 116 // Check Elements attribute element value is expected. 117 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 118 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 119 } 120 121 TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { 122 MLIRContext ctx; 123 IntegerType convertedType = IntegerType::get(8, &ctx); 124 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 125 TestUniformQuantizedValueConverter converter(quantizedType); 126 auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>( 127 &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); 128 129 Type returnedType; 130 auto returnedValue = 131 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 132 133 // Check Elements attribute shape and kind are not changed. 134 auto tensorType = returnedType.cast<TensorType>(); 135 auto expectedTensorType = realValue.getType().cast<TensorType>(); 136 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 137 EXPECT_EQ(tensorType.getElementType(), convertedType); 138 EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>()); 139 140 // Check Elements attribute element value is expected. 141 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 142 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 143 } 144 145 TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { 146 MLIRContext ctx; 147 IntegerType convertedType = IntegerType::get(8, &ctx); 148 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 149 TestUniformQuantizedValueConverter converter(quantizedType); 150 auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); 151 152 Type returnedType; 153 auto returnedValue = 154 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 155 156 // Check Elements attribute shape and kind are not changed. 157 auto tensorType = returnedType.cast<TensorType>(); 158 auto expectedTensorType = realValue.getType().cast<TensorType>(); 159 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 160 EXPECT_EQ(tensorType.getElementType(), convertedType); 161 EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>()); 162 163 // Check Elements attribute element value is expected. 164 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 165 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 166 } 167 168 } // end namespace 169