1 //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21
22 using namespace mlir;
23 using namespace quant;
24
parseStorageType(DialectAsmParser & parser,bool & isSigned)25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26 auto typeLoc = parser.getCurrentLocation();
27 IntegerType type;
28
29 // Parse storage type (alpha_ident, integer_literal).
30 StringRef identifier;
31 unsigned storageTypeWidth = 0;
32 OptionalParseResult result = parser.parseOptionalType(type);
33 if (result.hasValue()) {
34 if (!succeeded(*result))
35 return nullptr;
36 isSigned = !type.isUnsigned();
37 storageTypeWidth = type.getWidth();
38 } else if (succeeded(parser.parseKeyword(&identifier))) {
39 // Otherwise, this must be an unsigned integer (`u` integer-literal).
40 if (!identifier.consume_front("u")) {
41 parser.emitError(typeLoc, "illegal storage type prefix");
42 return nullptr;
43 }
44 if (identifier.getAsInteger(10, storageTypeWidth)) {
45 parser.emitError(typeLoc, "expected storage type width");
46 return nullptr;
47 }
48 isSigned = false;
49 type = parser.getBuilder().getIntegerType(storageTypeWidth);
50 } else {
51 return nullptr;
52 }
53
54 if (storageTypeWidth == 0 ||
55 storageTypeWidth > QuantizedType::MaxStorageBits) {
56 parser.emitError(typeLoc, "illegal storage type size: ")
57 << storageTypeWidth;
58 return nullptr;
59 }
60
61 return type;
62 }
63
parseStorageRange(DialectAsmParser & parser,IntegerType storageType,bool isSigned,int64_t & storageTypeMin,int64_t & storageTypeMax)64 static ParseResult parseStorageRange(DialectAsmParser &parser,
65 IntegerType storageType, bool isSigned,
66 int64_t &storageTypeMin,
67 int64_t &storageTypeMax) {
68 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
69 isSigned, storageType.getWidth());
70 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
71 isSigned, storageType.getWidth());
72 if (failed(parser.parseOptionalLess())) {
73 storageTypeMin = defaultIntegerMin;
74 storageTypeMax = defaultIntegerMax;
75 return success();
76 }
77
78 // Explicit storage min and storage max.
79 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
80 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
81 parser.getCurrentLocation(&maxLoc) ||
82 parser.parseInteger(storageTypeMax) || parser.parseGreater())
83 return failure();
84 if (storageTypeMin < defaultIntegerMin) {
85 return parser.emitError(minLoc, "illegal storage type minimum: ")
86 << storageTypeMin;
87 }
88 if (storageTypeMax > defaultIntegerMax) {
89 return parser.emitError(maxLoc, "illegal storage type maximum: ")
90 << storageTypeMax;
91 }
92 return success();
93 }
94
parseExpressedTypeAndRange(DialectAsmParser & parser,double & min,double & max)95 static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
96 double &min, double &max) {
97 auto typeLoc = parser.getCurrentLocation();
98 FloatType type;
99
100 if (failed(parser.parseType(type))) {
101 parser.emitError(typeLoc, "expecting float expressed type");
102 return nullptr;
103 }
104
105 // Calibrated min and max values.
106 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
107 parser.parseFloat(max) || parser.parseGreater()) {
108 parser.emitError(typeLoc, "calibrated values must be present");
109 return nullptr;
110 }
111 return type;
112 }
113
114 /// Parses an AnyQuantizedType.
115 ///
116 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
117 /// storage-spec ::= storage-type (`<` storage-range `>`)?
118 /// storage-range ::= integer-literal `:` integer-literal
119 /// storage-type ::= (`i` | `u`) integer-literal
120 /// expressed-type-spec ::= `:` `f` integer-literal
parseAnyType(DialectAsmParser & parser)121 static Type parseAnyType(DialectAsmParser &parser) {
122 IntegerType storageType;
123 FloatType expressedType;
124 unsigned typeFlags = 0;
125 int64_t storageTypeMin;
126 int64_t storageTypeMax;
127
128 // Type specification.
129 if (parser.parseLess())
130 return nullptr;
131
132 // Storage type.
133 bool isSigned = false;
134 storageType = parseStorageType(parser, isSigned);
135 if (!storageType) {
136 return nullptr;
137 }
138 if (isSigned) {
139 typeFlags |= QuantizationFlags::Signed;
140 }
141
142 // Storage type range.
143 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
144 storageTypeMax)) {
145 return nullptr;
146 }
147
148 // Optional expressed type.
149 if (succeeded(parser.parseOptionalColon())) {
150 if (parser.parseType(expressedType)) {
151 return nullptr;
152 }
153 }
154
155 if (parser.parseGreater()) {
156 return nullptr;
157 }
158
159 return parser.getChecked<AnyQuantizedType>(
160 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
161 }
162
parseQuantParams(DialectAsmParser & parser,double & scale,int64_t & zeroPoint)163 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
164 int64_t &zeroPoint) {
165 // scale[:zeroPoint]?
166 // scale.
167 if (parser.parseFloat(scale))
168 return failure();
169
170 // zero point.
171 zeroPoint = 0;
172 if (failed(parser.parseOptionalColon())) {
173 // Default zero point.
174 return success();
175 }
176
177 return parser.parseInteger(zeroPoint);
178 }
179
180 /// Parses a UniformQuantizedType.
181 ///
182 /// uniform_type ::= uniform_per_layer
183 /// | uniform_per_axis
184 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
185 /// `,` scale-zero `>`
186 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
187 /// axis-spec `,` scale-zero-list `>`
188 /// storage-spec ::= storage-type (`<` storage-range `>`)?
189 /// storage-range ::= integer-literal `:` integer-literal
190 /// storage-type ::= (`i` | `u`) integer-literal
191 /// expressed-type-spec ::= `:` `f` integer-literal
192 /// axis-spec ::= `:` integer-literal
193 /// scale-zero ::= float-literal `:` integer-literal
194 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
parseUniformType(DialectAsmParser & parser)195 static Type parseUniformType(DialectAsmParser &parser) {
196 IntegerType storageType;
197 FloatType expressedType;
198 unsigned typeFlags = 0;
199 int64_t storageTypeMin;
200 int64_t storageTypeMax;
201 bool isPerAxis = false;
202 int32_t quantizedDimension;
203 SmallVector<double, 1> scales;
204 SmallVector<int64_t, 1> zeroPoints;
205
206 // Type specification.
207 if (parser.parseLess()) {
208 return nullptr;
209 }
210
211 // Storage type.
212 bool isSigned = false;
213 storageType = parseStorageType(parser, isSigned);
214 if (!storageType) {
215 return nullptr;
216 }
217 if (isSigned) {
218 typeFlags |= QuantizationFlags::Signed;
219 }
220
221 // Storage type range.
222 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
223 storageTypeMax)) {
224 return nullptr;
225 }
226
227 // Expressed type.
228 if (parser.parseColon() || parser.parseType(expressedType)) {
229 return nullptr;
230 }
231
232 // Optionally parse quantized dimension for per-axis quantization.
233 if (succeeded(parser.parseOptionalColon())) {
234 if (parser.parseInteger(quantizedDimension))
235 return nullptr;
236 isPerAxis = true;
237 }
238
239 // Comma leading into range_spec.
240 if (parser.parseComma()) {
241 return nullptr;
242 }
243
244 // Parameter specification.
245 // For per-axis, ranges are in a {} delimitted list.
246 if (isPerAxis) {
247 if (parser.parseLBrace()) {
248 return nullptr;
249 }
250 }
251
252 // Parse scales/zeroPoints.
253 SMLoc scaleZPLoc = parser.getCurrentLocation();
254 do {
255 scales.resize(scales.size() + 1);
256 zeroPoints.resize(zeroPoints.size() + 1);
257 if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
258 return nullptr;
259 }
260 } while (isPerAxis && succeeded(parser.parseOptionalComma()));
261
262 if (isPerAxis) {
263 if (parser.parseRBrace()) {
264 return nullptr;
265 }
266 }
267
268 if (parser.parseGreater()) {
269 return nullptr;
270 }
271
272 if (!isPerAxis && scales.size() > 1) {
273 return (parser.emitError(scaleZPLoc,
274 "multiple scales/zeroPoints provided, but "
275 "quantizedDimension wasn't specified"),
276 nullptr);
277 }
278
279 if (isPerAxis) {
280 ArrayRef<double> scalesRef(scales.begin(), scales.end());
281 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
282 return parser.getChecked<UniformQuantizedPerAxisType>(
283 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
284 quantizedDimension, storageTypeMin, storageTypeMax);
285 }
286
287 return parser.getChecked<UniformQuantizedType>(
288 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
289 storageTypeMin, storageTypeMax);
290 }
291
292 /// Parses an CalibratedQuantizedType.
293 ///
294 /// calibrated ::= `calibrated<` expressed-spec `>`
295 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
296 /// expressed-type ::= `f` integer-literal
297 /// calibrated-range ::= float-literal `:` float-literal
parseCalibratedType(DialectAsmParser & parser)298 static Type parseCalibratedType(DialectAsmParser &parser) {
299 FloatType expressedType;
300 double min;
301 double max;
302
303 // Type specification.
304 if (parser.parseLess())
305 return nullptr;
306
307 // Expressed type.
308 expressedType = parseExpressedTypeAndRange(parser, min, max);
309 if (!expressedType) {
310 return nullptr;
311 }
312
313 if (parser.parseGreater()) {
314 return nullptr;
315 }
316
317 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
318 }
319
320 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const321 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
322 // All types start with an identifier that we switch on.
323 StringRef typeNameSpelling;
324 if (failed(parser.parseKeyword(&typeNameSpelling)))
325 return nullptr;
326
327 if (typeNameSpelling == "uniform")
328 return parseUniformType(parser);
329 if (typeNameSpelling == "any")
330 return parseAnyType(parser);
331 if (typeNameSpelling == "calibrated")
332 return parseCalibratedType(parser);
333
334 parser.emitError(parser.getNameLoc(),
335 "unknown quantized type " + typeNameSpelling);
336 return nullptr;
337 }
338
printStorageType(QuantizedType type,DialectAsmPrinter & out)339 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
340 // storage type
341 unsigned storageWidth = type.getStorageTypeIntegralWidth();
342 bool isSigned = type.isSigned();
343 if (isSigned) {
344 out << "i" << storageWidth;
345 } else {
346 out << "u" << storageWidth;
347 }
348
349 // storageTypeMin and storageTypeMax if not default.
350 int64_t defaultIntegerMin =
351 QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
352 int64_t defaultIntegerMax =
353 QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
354 if (defaultIntegerMin != type.getStorageTypeMin() ||
355 defaultIntegerMax != type.getStorageTypeMax()) {
356 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
357 << ">";
358 }
359 }
360
printQuantParams(double scale,int64_t zeroPoint,DialectAsmPrinter & out)361 static void printQuantParams(double scale, int64_t zeroPoint,
362 DialectAsmPrinter &out) {
363 out << scale;
364 if (zeroPoint != 0) {
365 out << ":" << zeroPoint;
366 }
367 }
368
369 /// Helper that prints a AnyQuantizedType.
printAnyQuantizedType(AnyQuantizedType type,DialectAsmPrinter & out)370 static void printAnyQuantizedType(AnyQuantizedType type,
371 DialectAsmPrinter &out) {
372 out << "any<";
373 printStorageType(type, out);
374 if (Type expressedType = type.getExpressedType()) {
375 out << ":" << expressedType;
376 }
377 out << ">";
378 }
379
380 /// Helper that prints a UniformQuantizedType.
printUniformQuantizedType(UniformQuantizedType type,DialectAsmPrinter & out)381 static void printUniformQuantizedType(UniformQuantizedType type,
382 DialectAsmPrinter &out) {
383 out << "uniform<";
384 printStorageType(type, out);
385 out << ":" << type.getExpressedType() << ", ";
386
387 // scheme specific parameters
388 printQuantParams(type.getScale(), type.getZeroPoint(), out);
389 out << ">";
390 }
391
392 /// Helper that prints a UniformQuantizedPerAxisType.
printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,DialectAsmPrinter & out)393 static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
394 DialectAsmPrinter &out) {
395 out << "uniform<";
396 printStorageType(type, out);
397 out << ":" << type.getExpressedType() << ":";
398 out << type.getQuantizedDimension();
399 out << ", ";
400
401 // scheme specific parameters
402 ArrayRef<double> scales = type.getScales();
403 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
404 out << "{";
405 llvm::interleave(
406 llvm::seq<size_t>(0, scales.size()), out,
407 [&](size_t index) {
408 printQuantParams(scales[index], zeroPoints[index], out);
409 },
410 ",");
411 out << "}>";
412 }
413
414 /// Helper that prints a CalibratedQuantizedType.
printCalibratedQuantizedType(CalibratedQuantizedType type,DialectAsmPrinter & out)415 static void printCalibratedQuantizedType(CalibratedQuantizedType type,
416 DialectAsmPrinter &out) {
417 out << "calibrated<" << type.getExpressedType();
418 out << "<" << type.getMin() << ":" << type.getMax() << ">";
419 out << ">";
420 }
421
422 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const423 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
424 if (auto anyType = type.dyn_cast<AnyQuantizedType>())
425 printAnyQuantizedType(anyType, os);
426 else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
427 printUniformQuantizedType(uniformType, os);
428 else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
429 printUniformQuantizedPerAxisType(perAxisType, os);
430 else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
431 printCalibratedQuantizedType(calibratedType, os);
432 else
433 llvm_unreachable("Unhandled quantized type");
434 }
435