//===- UniformSupport.cpp - Support utilities for uniform quant ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/StandardTypes.h"
#include <numeric>

using namespace mlir;
using namespace mlir::quant;

static bool isQuantizablePrimitiveType(Type inputType) {
  return inputType.isa<FloatType>();
}

const ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
  switch (inputType.getKind()) {
  default:
    if (isQuantizablePrimitiveType(inputType)) {
      // Supported primitive type (which is just the expressed type).
      return ExpressedToQuantizedConverter{inputType, inputType};
    }
    // Unsupported.
    return ExpressedToQuantizedConverter{inputType, nullptr};
  case StandardTypes::RankedTensor:
  case StandardTypes::UnrankedTensor:
  case StandardTypes::Vector: {
    Type elementType = inputType.cast<ShapedType>().getElementType();
    if (!isQuantizablePrimitiveType(elementType)) {
      // Unsupported.
      return ExpressedToQuantizedConverter{inputType, nullptr};
    }
    return ExpressedToQuantizedConverter{inputType, elementType};
  }
  }
}

Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
  assert(expressedType && "convert() on unsupported conversion");

  switch (inputType.getKind()) {
  default:
    if (elementalType.getExpressedType() == expressedType) {
      // If the expressed types match, just use the new elemental type.
      return elementalType;
    }
    // Unsupported.
    return nullptr;
  case StandardTypes::RankedTensor:
    return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
                                 elementalType);
  case StandardTypes::UnrankedTensor:
    return UnrankedTensorType::get(elementalType);
  case StandardTypes::Vector:
    return VectorType::get(inputType.cast<VectorType>().getShape(),
                           elementalType);
  }
}

ElementsAttr
UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
    return convert(attr);
  }
  // TODO(fengliuai): handle sparse elements attributes.
  return nullptr;
}

DenseElementsAttr
UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
  // Create one converter per chunk along the quantization dimension. That
  // dimension is normally small (e.g. 3), so all the converters can be cached.
  ShapedType type = attr.getType();
  size_t dimSize = type.getDimSize(quantizationDim);
  if (dimSize != scales.size()) {
    return {};
  }
  SmallVector<UniformQuantizedValueConverter, 4> converters;
  converters.reserve(dimSize);
  for (int i = 0, e = dimSize; i != e; ++i) {
    converters.push_back(getPerChunkConverter(i));
  }

  // Scan the elements of the dense elements attribute and quantize each one
  // with the quantization parameters of the chunk it belongs to.
  int64_t flattenIndex = 0;
  auto shape = type.getShape();
  int64_t chunkSize =
      std::accumulate(std::next(shape.begin(), quantizationDim + 1),
                      shape.end(), 1, std::multiplies<int64_t>());
  Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
  return attr.mapValues(newElementType, [&](const APFloat &old) {
    int chunkIndex = (flattenIndex++) / chunkSize;
    return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
  });
}
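
// Usage sketch (illustrative only, not part of this translation unit's
// interface): given a per-axis uniform quantized type, a caller would
// typically build the converter from that type and quantize a dense
// floating-point constant with it. The values `perAxisType` and `realConst`
// below are hypothetical inputs supplied by the caller:
//
//   UniformQuantizedPerAxisValueConverter converter(perAxisType);
//   ElementsAttr quantizedConst = converter.convert(realConst);
//
// As implemented above, `convert` returns a null attribute when `realConst`
// is not a DenseFPElementsAttr or when the size of the quantization
// dimension does not match the number of per-axis scales.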