1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15
16 using namespace mlir;
17 using namespace mlir::tosa;
18
/// Decomposes \p scale into a 16-bit-regime fixed-point pair such that
/// scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits \p scale into a fraction (magnitude in [0.5, 1.0), or 0)
/// and a binary exponent. The fraction is scaled by 2^15 and rounded to form
/// \p multiplier; \p shift is the matching right-shift amount, capped at 63.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 15;
  constexpr int32_t kMaxShift = 63;
  constexpr int64_t kFixedPointOne = int64_t(1) << kFractionBits;

  int exponent = 0;
  const double fraction = std::frexp(scale, &exponent);
  double scaled = std::round(fraction * kFixedPointOne);

  // The fraction never exceeds 1.0 in magnitude.
  assert(scaled <= kFixedPointOne &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding may land exactly on 1.0; renormalize to 0.5 and bump the exponent.
  if (scaled == kFixedPointOne) {
    scaled /= 2;
    ++exponent;
  }

  // TOSA wants a positive right shift that also absorbs the 2^15 factor folded
  // into the multiplier.
  int32_t rightShift = kFractionBits - exponent;

  assert(scaled <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  int32_t quantized = static_cast<int32_t>(scaled);

  // The shift may not exceed 63: move the excess into the multiplier instead.
  // Shifting the multiplier by more than 31 bits is never needed.
  if (rightShift > kMaxShift) {
    quantized >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = quantized;
  shift = rightShift;
}
54
/// Decomposes \p scale into a 32-bit-regime fixed-point pair such that
/// scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits \p scale into a fraction (magnitude in [0.5, 1.0), or 0)
/// and a binary exponent. The fraction is scaled by 2^31 and rounded to form
/// \p multiplier; \p shift is the matching right-shift amount, capped at 63.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 31;
  constexpr int32_t kMaxShift = 63;
  constexpr int64_t kFixedPointOne = int64_t(1) << kFractionBits;

  int exponent = 0;
  const double fraction = std::frexp(scale, &exponent);
  double scaled = std::round(fraction * kFixedPointOne);

  // The fraction never exceeds 1.0 in magnitude.
  assert(scaled <= kFixedPointOne &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding may land exactly on 1.0; renormalize to 0.5 and bump the exponent
  // so that the value fits a signed 32-bit multiplier.
  if (scaled == kFixedPointOne) {
    scaled /= 2;
    ++exponent;
  }

  // TOSA wants a positive right shift that also absorbs the 2^31 factor folded
  // into the multiplier.
  int32_t rightShift = kFractionBits - exponent;

  assert(scaled <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  int32_t quantized = static_cast<int32_t>(scaled);

  // The shift may not exceed 63: move the excess into the multiplier instead.
  // Shifting the multiplier by more than 31 bits is never needed.
  if (rightShift > kMaxShift) {
    quantized >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = quantized;
  shift = rightShift;
}
89
90 /// Generates a quantized multiplier/shift from double.
computeMultiplierAndShift(double scale,int32_t & multiplier,int32_t & shift,int32_t scaleWidth)91 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
92 int32_t &shift, int32_t scaleWidth) {
93
94 switch (scaleWidth) {
95 case 16:
96 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
97 return;
98 case 32:
99 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
100 return;
101 default:
102 assert(0 && "Unsupported Tosa quantized_scale regime specified!");
103 }
104 }
105
// Element-type cast helpers. Each takes a value exposing getElementType() and
// yields that element type cast to the requested quantized type; the result
// is null when the element type is of a different kind.
#define GET_UQTYPE(shapedTy)                                                   \
  ((shapedTy).getElementType().dyn_cast<quant::UniformQuantizedType>())
#define GET_QTYPE(shapedTy)                                                    \
  ((shapedTy).getElementType().dyn_cast<quant::QuantizedType>())
110
111 /// Method to build ConvOpQuantizationAttr, called from
112 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
113 /// input_zp: input zeropoint
114 /// weight_zp: weight zeropoint.
115 ConvOpQuantizationAttr
buildConvOpQuantizationAttr(OpBuilder & builder,Value input,Value weight)116 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
117 Value weight) {
118
119 auto inputType = input.getType().dyn_cast<ShapedType>();
120 auto weightType = weight.getType().dyn_cast<ShapedType>();
121
122 if (!inputType || !weightType)
123 return nullptr;
124
125 auto inputQType = GET_UQTYPE(inputType);
126 auto weightPerTensorQType = GET_UQTYPE(weightType);
127 auto weightPerAxisQType = weightType.getElementType()
128 .dyn_cast<quant::UniformQuantizedPerAxisType>();
129
130 // Weights must be either per-tensor quantized or per-axis quantized.
131 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
132 "Weights must be either per-tensor or per-axis quantized");
133
134 // Either all quantized or all not quantized.
135 assert(!((bool)inputQType ^
136 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
137 "Inputs and weights must be all quantized or all not quantized");
138
139 if (inputQType) {
140 int64_t inputZp = inputQType.getZeroPoint();
141 int64_t weightZp = 0;
142
143 if (weightPerTensorQType) {
144 weightZp = weightPerTensorQType.getZeroPoint();
145 } else if (weightPerAxisQType) {
146 weightZp = weightPerAxisQType.getZeroPoints().front();
147 }
148
149 return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
150 }
151
152 return nullptr;
153 }
154
155 /// Builds MatMulOpQuantizationAttr, called from
156 /// MatMulOpQuantInfoBuilder:
157 /// aZp: input a zeropoint
158 /// bZp: input b zeropoint.
159 MatMulOpQuantizationAttr
buildMatMulOpQuantizationAttr(OpBuilder & builder,Value a,Value b)160 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
161 Value b) {
162
163 auto aType = a.getType().dyn_cast<ShapedType>();
164 auto bType = b.getType().dyn_cast<ShapedType>();
165
166 if (!aType || !bType)
167 return nullptr;
168
169 auto aQType = GET_UQTYPE(aType);
170 auto bQType = GET_UQTYPE(bType);
171
172 // A and B are either all quantized or all not quantized.
173 assert(!((bool)aQType ^ (bool)bQType) &&
174 "Matmul operands must be all quantized or all not quantized");
175
176 if (aQType) {
177 return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
178 aQType.getZeroPoint(), bQType.getZeroPoint());
179 }
180
181 return nullptr;
182 }
183
184 /// Builds UnaryOpQuantizationAttr
185 /// UnaryOpQuantInfoBuilder:
186 /// inputZp: input zeropoint
187 /// outputZp: output zeropoint.
188 UnaryOpQuantizationAttr
buildUnaryOpQuantizationAttr(OpBuilder & builder,Value input,Type outputRawType)189 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
190 Type outputRawType) {
191
192 auto inputType = input.getType().dyn_cast<ShapedType>();
193 auto outputType = outputRawType.dyn_cast<ShapedType>();
194
195 if (!inputType || !outputType)
196 return nullptr;
197
198 auto inputQType = GET_UQTYPE(inputType);
199 auto outputQType = GET_UQTYPE(outputType);
200
201 // Either all quantized or all not quantized.
202 assert(!((bool)inputQType ^ (bool)outputQType) &&
203 "Unary inputs/outputs must be all quantized or all not quantized");
204
205 if (inputQType) {
206 return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
207 outputQType.getZeroPoint());
208 }
209
210 return nullptr;
211 }
212
213 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
214 /// inputZp: input zeropoint.
buildPadOpQuantizationAttr(OpBuilder & builder,Value input)215 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
216 Value input) {
217
218 auto inputType = input.getType().dyn_cast<ShapedType>();
219
220 if (!inputType)
221 return nullptr;
222
223 auto inputQType = GET_UQTYPE(inputType);
224
225 if (inputQType) {
226 return builder.getAttr<tosa::PadOpQuantizationAttr>(
227 inputQType.getZeroPoint());
228 }
229
230 return nullptr;
231 }
232
233 /// Builds output type for a quantized ConvOp with the right bitwidth.
234 /// This is called by the builder when dealing with quantized content.
buildConvOpResultTypeInfo(OpBuilder & builder,Type outputType,Value input,Value weight)235 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
236 Value input, Value weight) {
237
238 auto inputType = input.getType().dyn_cast<ShapedType>();
239 auto weightType = weight.getType().dyn_cast<ShapedType>();
240
241 assert(inputType && weightType &&
242 "Could not extract input or weight tensors from Conv op");
243
244 auto inputQType = GET_QTYPE(inputType);
245 auto weightQType = GET_QTYPE(weightType);
246
247 assert(inputQType && weightQType &&
248 "Could not extract input or weight tensor types from Conv op");
249
250 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
251 unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
252
253 auto outputShapedType = outputType.dyn_cast<ShapedType>();
254 assert(outputShapedType &&
255 "Could not extract output shape type from Conv op");
256
257 IntegerType accElementType;
258 if (inputBits == 16 && weightBits == 8)
259 accElementType = builder.getIntegerType(48);
260 else
261 accElementType = builder.getI32Type();
262 auto accType = outputShapedType.clone(accElementType);
263 return accType;
264 }
265
266 /// Builds Tosa quantization attributes from min/max values.
buildQTypeFromMinMax(OpBuilder builder,Type inputDType,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)267 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
268 Attribute minAttr, Attribute maxAttr,
269 IntegerAttr quantBits, int filterQuantDim,
270 bool isSigned, BoolAttr narrowRange) {
271
272 quant::QuantizedType retType;
273
274 auto convfunc =
275 quant::ExpressedToQuantizedConverter::forInputType(inputDType);
276
277 auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
278 auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
279
280 SmallVector<double, 2> min, max;
281
282 // At least one is per-axis quantized elementsattr.
283 if (minElems || maxElems) {
284 // Must have the same number of elements.
285 if (minElems.getNumElements() != maxElems.getNumElements())
286 return {};
287 min.reserve(minElems.getNumElements());
288 max.reserve(maxElems.getNumElements());
289 for (auto i : minElems)
290 min.push_back(FloatAttr::getValueAsDouble(i));
291 for (auto i : maxElems)
292 max.push_back(FloatAttr::getValueAsDouble(i));
293 } else { // Just a single FP value.
294 auto minVal = minAttr.dyn_cast<FloatAttr>();
295 if (minVal)
296 min.push_back(minVal.getValueAsDouble());
297 else
298 return {};
299 auto maxVal = maxAttr.dyn_cast<FloatAttr>();
300 if (maxVal)
301 max.push_back(maxVal.getValueAsDouble());
302 else
303 return {};
304 }
305
306 if (min.size() == max.size()) {
307 if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
308 retType = quant::fakeQuantAttrsToType(
309 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
310 narrowRange.getValue(), convfunc.expressedType, isSigned);
311 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
312 auto shape = inputDType.dyn_cast<ShapedType>();
313 if (!shape)
314 return {};
315 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
316 retType = quant::fakeQuantAttrsToType(
317 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
318 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
319 }
320 } else {
321 return {};
322 }
323 } else {
324 return {};
325 }
326
327 if (!retType)
328 return {};
329
330 return convfunc.convert(retType);
331 }
332
333 /// Builds Tosa quantization attributes from min/max values.
334 TypeAttr
buildQTypeAttrFromMinMax(OpBuilder builder,Type inputDtype,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)335 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
336 Attribute minAttr, Attribute maxAttr,
337 IntegerAttr quantBits, int filterQuantDim,
338 bool isSigned, BoolAttr narrowRange) {
339
340 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
341 maxAttr, quantBits, filterQuantDim,
342 isSigned, narrowRange));
343 }
344