//===- QuantUtils.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains TOSA numerical support functions and quantization
// attribute builders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

using namespace mlir;
using namespace mlir::tosa;

/// From a scale value, generates multiplier and shift values where
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
/// multiplier = mantissa*2^shift for 16-bit scaling.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
  if (shift > 63) {
    // Shifting the multiplier by more than 31-bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
    shift = 63;
  }
}
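
// Worked example (illustrative): for scale = 0.25, std::frexp returns a
// mantissa of 0.5 and an exponent of -1, so
//   shiftedM   = round(0.5 * 2^15) = 16384
//   shift      = -(-1) + 15        = 16
//   multiplier = 16384
// and the original scale is recovered as 16384 * 2^-16 = 0.25.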

/// From a scale value, generates multiplier and shift values where
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
/// multiplier = mantissa*2^shift for 32-bit scaling.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
  if (shift > 63) {
    // Shifting the multiplier by more than 31-bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
    shift = 63;
  }
}
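
// Worked example (illustrative): for scale = 1.0 / 255.0, std::frexp returns
// a mantissa of ~0.50196 and an exponent of -7, so
//   shiftedM   = round(0.50196... * 2^31) = 1077952576
//   shift      = -(-7) + 31               = 38
//   multiplier = 1077952576
// and 1077952576 * 2^-38 recovers ~0.0039216, i.e. approximately 1/255.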

/// Generates a quantized multiplier/shift from double.
void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
                                           int32_t &shift, int32_t scaleWidth) {

  switch (scaleWidth) {
  case 16:
    computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
    return;
  case 32:
    computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
    return;
  default:
    assert(0 && "Unsupported Tosa quantized_scale regime specified!");
  }
}
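
// Usage sketch (illustrative): a lowering that rescales an accumulator by
// `scale` would typically compute
//
//   int32_t multiplier, shift;
//   computeMultiplierAndShift(scale, multiplier, shift, /*scaleWidth=*/32);
//
// and then emit the equivalent of a tosa.rescale, i.e. roughly
// (value * multiplier + (1 << (shift - 1))) >> shift in 64-bit arithmetic.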

#define GET_UQTYPE(input_type)                                                 \
  ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
#define GET_QTYPE(input_type)                                                  \
  ((input_type).getElementType().dyn_cast<quant::QuantizedType>())

/// Method to build ConvOpQuantizationAttr, called from
/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
/// input_zp: input zeropoint
/// weight_zp: weight zeropoint.
ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
                                        Value weight) {

  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto weightType = weight.getType().dyn_cast<ShapedType>();

  if (!inputType || !weightType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto weightPerTensorQType = GET_UQTYPE(weightType);
  auto weightPerAxisQType = weightType.getElementType()
                                .dyn_cast<quant::UniformQuantizedPerAxisType>();

  // Weights must be either per-tensor quantized or per-axis quantized.
  assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
         "Weights must be either per-tensor or per-axis quantized");

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^
           ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
         "Inputs and weights must be all quantized or all not quantized");

  if (inputQType) {
    int64_t inputZp = inputQType.getZeroPoint();
    int64_t weightZp = 0;

    if (weightPerTensorQType) {
      weightZp = weightPerTensorQType.getZeroPoint();
    } else if (weightPerAxisQType) {
      weightZp = weightPerAxisQType.getZeroPoints().front();
    }

    return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
  }

  return nullptr;
}
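
// Example (illustrative): for an input of type
// tensor<1x32x32x8x!quant.uniform<i8:f32, 0.025:-1>> and per-axis quantized
// weights whose zero points are all 0, this returns a ConvOpQuantizationAttr
// with input_zp = -1 and weight_zp = 0; for unquantized float operands it
// returns nullptr.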

/// Builds MatMulOpQuantizationAttr, called from
/// MatMulOpQuantInfoBuilder:
/// aZp: input a zeropoint
/// bZp: input b zeropoint.
MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
                                          Value b) {

  auto aType = a.getType().dyn_cast<ShapedType>();
  auto bType = b.getType().dyn_cast<ShapedType>();

  if (!aType || !bType)
    return nullptr;

  auto aQType = GET_UQTYPE(aType);
  auto bQType = GET_UQTYPE(bType);

  // A and B are either all quantized or all not quantized.
  assert(!((bool)aQType ^ (bool)bQType) &&
         "Matmul operands must be all quantized or all not quantized");

  if (aQType) {
    return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
        aQType.getZeroPoint(), bQType.getZeroPoint());
  }

  return nullptr;
}

/// Builds UnaryOpQuantizationAttr, called from
/// UnaryOpQuantInfoBuilder:
/// inputZp: input zeropoint
/// outputZp: output zeropoint.
UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
                                         Type outputRawType) {

  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto outputType = outputRawType.dyn_cast<ShapedType>();

  if (!inputType || !outputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto outputQType = GET_UQTYPE(outputType);

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^ (bool)outputQType) &&
         "Unary inputs/outputs must be all quantized or all not quantized");

  if (inputQType) {
    return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
                                                    outputQType.getZeroPoint());
  }

  return nullptr;
}

/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
/// inputZp: input zeropoint.
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
                                                             Value input) {

  auto inputType = input.getType().dyn_cast<ShapedType>();

  if (!inputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);

  if (inputQType) {
    return builder.getAttr<tosa::PadOpQuantizationAttr>(
        inputQType.getZeroPoint());
  }

  return nullptr;
}

/// Builds output type for a quantized ConvOp with the right bitwidth.
/// This is called by the builder when dealing with quantized content.
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
                                           Value input, Value weight) {

  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto weightType = weight.getType().dyn_cast<ShapedType>();

  assert(inputType && weightType &&
         "Could not extract input or weight tensors from Conv op");

  auto inputQType = GET_QTYPE(inputType);
  auto weightQType = GET_QTYPE(weightType);

  assert(inputQType && weightQType &&
         "Could not extract input or weight tensor types from Conv op");

  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
  unsigned weightBits = weightQType.getStorageTypeIntegralWidth();

  auto outputShapedType = outputType.dyn_cast<ShapedType>();
  assert(outputShapedType &&
         "Could not extract output shape type from Conv op");

  IntegerType accElementType;
  if (inputBits == 16 && weightBits == 8)
    accElementType = builder.getIntegerType(48);
  else
    accElementType = builder.getI32Type();
  auto accType = outputShapedType.clone(accElementType);
  return accType;
}
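
// Example (illustrative): an i8 input with i8 weights produces an i32
// accumulator/result element type, while an i16 input with i8 weights is
// widened to i48, matching the convolution accumulator widths used by the
// TOSA specification.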

/// Builds Tosa quantization attributes from min/max values.
Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
                                      Attribute minAttr, Attribute maxAttr,
                                      IntegerAttr quantBits, int filterQuantDim,
                                      bool isSigned, BoolAttr narrowRange) {

  quant::QuantizedType retType;

  auto convfunc =
      quant::ExpressedToQuantizedConverter::forInputType(inputDType);

  auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
  auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();

  SmallVector<double, 2> min, max;

  // At least one is a per-axis quantized elementsattr.
  if (minElems || maxElems) {
    // Both must be elements attrs with the same number of elements.
    if (!minElems || !maxElems ||
        minElems.getNumElements() != maxElems.getNumElements())
      return {};
    min.reserve(minElems.getNumElements());
    max.reserve(maxElems.getNumElements());
    for (auto i : minElems)
      min.push_back(FloatAttr::getValueAsDouble(i));
    for (auto i : maxElems)
      max.push_back(FloatAttr::getValueAsDouble(i));
  } else { // Just a single FP value.
    auto minVal = minAttr.dyn_cast<FloatAttr>();
    if (minVal)
      min.push_back(minVal.getValueAsDouble());
    else
      return {};
    auto maxVal = maxAttr.dyn_cast<FloatAttr>();
    if (maxVal)
      max.push_back(maxVal.getValueAsDouble());
    else
      return {};
  }

  if (min.size() == max.size()) {
    if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
      retType = quant::fakeQuantAttrsToType(
          builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
          narrowRange.getValue(), convfunc.expressedType, isSigned);
    } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
      auto shape = inputDType.dyn_cast<ShapedType>();
      if (!shape)
        return {};
      if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
        retType = quant::fakeQuantAttrsToType(
            builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
            max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
      }
    } else {
      return {};
    }
  } else {
    return {};
  }

  if (!retType)
    return {};

  return convfunc.convert(retType);
}
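
// Example (illustrative): with scalar FloatAttrs min = -1.0f and max = 1.0f,
// quantBits = 8, isSigned = true and narrowRange = false, the per-tensor path
// is taken and the result is a !quant.uniform<i8:f32, ...> type whose scale is
// approximately (max - min) / 255; DenseFPElementsAttr min/max values select
// the per-axis path on filterQuantDim instead.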

/// Builds Tosa quantization attributes from min/max values.
TypeAttr
mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                     Attribute minAttr, Attribute maxAttr,
                                     IntegerAttr quantBits, int filterQuantDim,
                                     bool isSigned, BoolAttr narrowRange) {

  return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
                                            maxAttr, quantBits, filterQuantDim,
                                            isSigned, narrowRange));
}