//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/StandardTypes.h"
#include <numeric>

using namespace mlir;
using namespace mlir::quant;

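// Returns true if the given primitive type can be used as the expressed type
// of a quantized type. Only floating point types are currently supported.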
static bool isQuantizablePrimitiveType(Type inputType) {
  return inputType.isa<FloatType>();
}

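// Builds a converter for the given input type. A supported primitive type is
// its own expressed type; for vector and tensor types the expressed type is
// the element type (e.g. f32 for tensor<4xf32>). Unsupported inputs produce a
// converter with a null expressed type.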
const ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
  switch (inputType.getKind()) {
  default:
    if (isQuantizablePrimitiveType(inputType)) {
      // Supported primitive type (which just is the expressed type).
      return ExpressedToQuantizedConverter{inputType, inputType};
    }
    // Unsupported.
    return ExpressedToQuantizedConverter{inputType, nullptr};
  case StandardTypes::RankedTensor:
  case StandardTypes::UnrankedTensor:
  case StandardTypes::Vector: {
    Type elementType = inputType.cast<ShapedType>().getElementType();
    if (!isQuantizablePrimitiveType(elementType)) {
      // Unsupported.
      return ExpressedToQuantizedConverter{inputType, nullptr};
    }
    return ExpressedToQuantizedConverter{inputType, elementType};
  }
  }
}

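// Converts the stored input type so that it carries `elementalType`: vector
// and tensor inputs keep their shape with the element type replaced, and a
// primitive input is replaced directly when `elementalType` expresses it.
// Returns nullptr for unsupported conversions.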
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
  assert(expressedType && "convert() on unsupported conversion");

  switch (inputType.getKind()) {
  default:
    if (elementalType.getExpressedType() == expressedType) {
      // If the expressed types match, just use the new elemental type.
      return elementalType;
    }
    // Unsupported.
    return nullptr;
  case StandardTypes::RankedTensor:
    return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
                                 elementalType);
  case StandardTypes::UnrankedTensor:
    return UnrankedTensorType::get(elementalType);
  case StandardTypes::Vector:
    return VectorType::get(inputType.cast<VectorType>().getShape(),
                           elementalType);
  }
}

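// Converts a real (expressed) value attribute to its quantized storage form.
// Only dense fp elements attributes are currently handled; any other
// attribute yields a null result.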
ElementsAttr
UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
    return convert(attr);
  }
  // TODO(fengliuai): handle sparse elements attributes.
  return nullptr;
}

DenseElementsAttr
UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
  // Create a converter for each chunk. The quantization dimension is normally
  // small (e.g. 3), so we can cache all the converters.
  ShapedType type = attr.getType();
  size_t dimSize = type.getDimSize(quantizationDim);
  if (dimSize != scales.size()) {
    return {};
  }
  SmallVector<UniformQuantizedValueConverter, 4> converters;
  converters.reserve(dimSize);
  for (int i = 0, e = dimSize; i != e; ++i) {
    converters.push_back(getPerChunkConverter(i));
  }

  // Scan the elements of the dense elements attribute and quantize each one
  // with the quantization parameters of its chunk.
  int64_t flattenIndex = 0;
  auto shape = type.getShape();
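  // A chunk is a contiguous run of elements that share the same quantization
  // parameters: its size is the product of all dimensions after the
  // quantization dimension.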
  int64_t chunkSize =
      std::accumulate(std::next(shape.begin(), quantizationDim + 1),
                      shape.end(), int64_t{1}, std::multiplies<int64_t>());
  Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
  return attr.mapValues(newElementType, [&](const APFloat &old) {
    int chunkIndex = (flattenIndex++) / chunkSize;
    return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
  });
}