1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantOps.h" 10 #include "mlir/Dialect/Quant/QuantizeUtils.h" 11 #include "mlir/Dialect/Quant/UniformSupport.h" 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "gmock/gmock.h" 15 #include "gtest/gtest.h" 16 17 using namespace mlir; 18 using namespace mlir::quant; 19 20 namespace { 21 22 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. 23 class TestUniformQuantizedValueConverter 24 : public UniformQuantizedValueConverter { 25 public: 26 TestUniformQuantizedValueConverter(UniformQuantizedType type) 27 : UniformQuantizedValueConverter(type), qtype(type) {} 28 APInt quantizeFloatToInt(APFloat expressedValue) const { 29 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L); 30 } 31 32 private: 33 UniformQuantizedType qtype; 34 }; 35 36 Attribute getTestFloatAttr(double value, MLIRContext *ctx) { 37 return FloatAttr::get(FloatType::getF32(ctx), value); 38 } 39 40 template <typename ConcreteAttrClass, typename... Arg> 41 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, 42 Arg... value) { 43 auto eleType = FloatType::getF32(ctx); 44 ShapedType tensorType; 45 if (shape.size() == 1 && shape[0] == -1) { 46 tensorType = UnrankedTensorType::get(eleType); 47 } else { 48 tensorType = RankedTensorType::get(shape, eleType); 49 } 50 return ConcreteAttrClass::get(tensorType, value...); 51 } 52 53 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, 54 ArrayRef<int64_t> shape) { 55 auto eleType = FloatType::getF32(ctx); 56 ShapedType tensorType; 57 if (shape.size() == 1 && shape[0] == -1) { 58 tensorType = UnrankedTensorType::get(eleType); 59 } else { 60 tensorType = RankedTensorType::get(shape, eleType); 61 } 62 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64)); 63 auto indices = 64 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 65 auto valuesType = RankedTensorType::get({1}, eleType); 66 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)}); 67 return SparseElementsAttr::get(tensorType, indices, values); 68 } 69 70 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { 71 return UniformQuantizedType::get(/*flags=*/false, storageType, 72 FloatType::getF32(ctx), /*scale=*/1.0, 73 /*zeroPoint=*/0, /*storageTypeMin=*/0, 74 /*storageTypeMax=*/255); 75 } 76 77 TEST(QuantizationUtilsTest, convertFloatAttrUniform) { 78 MLIRContext ctx; 79 ctx.getOrLoadDialect<QuantizationDialect>(); 80 IntegerType convertedType = IntegerType::get(&ctx, 8); 81 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 82 TestUniformQuantizedValueConverter converter(quantizedType); 83 84 auto realValue = getTestFloatAttr(1.0, &ctx); 85 Type typeResult; 86 auto valueResult = 87 quantizeAttrUniform(realValue, quantizedType, converter, typeResult); 88 89 EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5); 90 EXPECT_EQ( 91 valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(), 92 convertedType.getWidth()); 93 } 94 95 TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { 96 MLIRContext ctx; 97 ctx.getOrLoadDialect<QuantizationDialect>(); 98 IntegerType convertedType = IntegerType::get(&ctx, 8); 99 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 100 TestUniformQuantizedValueConverter converter(quantizedType); 101 auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>( 102 &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)}); 103 104 Type returnedType; 105 auto returnedValue = 106 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 107 108 // Check Elements attribute shape and kind are not changed. 109 auto tensorType = returnedType.cast<TensorType>(); 110 auto expectedTensorType = realValue.getType().cast<TensorType>(); 111 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 112 EXPECT_EQ(tensorType.getElementType(), convertedType); 113 EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>()); 114 115 // Check Elements attribute element value is expected. 116 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 117 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 118 } 119 120 TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { 121 MLIRContext ctx; 122 ctx.getOrLoadDialect<QuantizationDialect>(); 123 IntegerType convertedType = IntegerType::get(&ctx, 8); 124 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 125 TestUniformQuantizedValueConverter converter(quantizedType); 126 auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>( 127 &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); 128 129 Type returnedType; 130 auto returnedValue = 131 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 132 133 // Check Elements attribute shape and kind are not changed. 134 auto tensorType = returnedType.cast<TensorType>(); 135 auto expectedTensorType = realValue.getType().cast<TensorType>(); 136 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 137 EXPECT_EQ(tensorType.getElementType(), convertedType); 138 EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>()); 139 140 // Check Elements attribute element value is expected. 141 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 142 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 143 } 144 145 TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { 146 MLIRContext ctx; 147 ctx.getOrLoadDialect<QuantizationDialect>(); 148 IntegerType convertedType = IntegerType::get(&ctx, 8); 149 auto quantizedType = getTestQuantizedType(convertedType, &ctx); 150 TestUniformQuantizedValueConverter converter(quantizedType); 151 auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); 152 153 Type returnedType; 154 auto returnedValue = 155 quantizeAttrUniform(realValue, quantizedType, converter, returnedType); 156 157 // Check Elements attribute shape and kind are not changed. 158 auto tensorType = returnedType.cast<TensorType>(); 159 auto expectedTensorType = realValue.getType().cast<TensorType>(); 160 EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); 161 EXPECT_EQ(tensorType.getElementType(), convertedType); 162 EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>()); 163 164 // Check Elements attribute element value is expected. 165 auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0}); 166 EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5); 167 } 168 169 } // end namespace 170