1 //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace quant;
24 
parseStorageType(DialectAsmParser & parser,bool & isSigned)25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26   auto typeLoc = parser.getCurrentLocation();
27   IntegerType type;
28 
29   // Parse storage type (alpha_ident, integer_literal).
30   StringRef identifier;
31   unsigned storageTypeWidth = 0;
32   OptionalParseResult result = parser.parseOptionalType(type);
33   if (result.hasValue()) {
34     if (!succeeded(*result))
35       return nullptr;
36     isSigned = !type.isUnsigned();
37     storageTypeWidth = type.getWidth();
38   } else if (succeeded(parser.parseKeyword(&identifier))) {
39     // Otherwise, this must be an unsigned integer (`u` integer-literal).
40     if (!identifier.consume_front("u")) {
41       parser.emitError(typeLoc, "illegal storage type prefix");
42       return nullptr;
43     }
44     if (identifier.getAsInteger(10, storageTypeWidth)) {
45       parser.emitError(typeLoc, "expected storage type width");
46       return nullptr;
47     }
48     isSigned = false;
49     type = parser.getBuilder().getIntegerType(storageTypeWidth);
50   } else {
51     return nullptr;
52   }
53 
54   if (storageTypeWidth == 0 ||
55       storageTypeWidth > QuantizedType::MaxStorageBits) {
56     parser.emitError(typeLoc, "illegal storage type size: ")
57         << storageTypeWidth;
58     return nullptr;
59   }
60 
61   return type;
62 }
63 
parseStorageRange(DialectAsmParser & parser,IntegerType storageType,bool isSigned,int64_t & storageTypeMin,int64_t & storageTypeMax)64 static ParseResult parseStorageRange(DialectAsmParser &parser,
65                                      IntegerType storageType, bool isSigned,
66                                      int64_t &storageTypeMin,
67                                      int64_t &storageTypeMax) {
68   int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
69       isSigned, storageType.getWidth());
70   int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
71       isSigned, storageType.getWidth());
72   if (failed(parser.parseOptionalLess())) {
73     storageTypeMin = defaultIntegerMin;
74     storageTypeMax = defaultIntegerMax;
75     return success();
76   }
77 
78   // Explicit storage min and storage max.
79   SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
80   if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
81       parser.getCurrentLocation(&maxLoc) ||
82       parser.parseInteger(storageTypeMax) || parser.parseGreater())
83     return failure();
84   if (storageTypeMin < defaultIntegerMin) {
85     return parser.emitError(minLoc, "illegal storage type minimum: ")
86            << storageTypeMin;
87   }
88   if (storageTypeMax > defaultIntegerMax) {
89     return parser.emitError(maxLoc, "illegal storage type maximum: ")
90            << storageTypeMax;
91   }
92   return success();
93 }
94 
parseExpressedTypeAndRange(DialectAsmParser & parser,double & min,double & max)95 static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
96                                             double &min, double &max) {
97   auto typeLoc = parser.getCurrentLocation();
98   FloatType type;
99 
100   if (failed(parser.parseType(type))) {
101     parser.emitError(typeLoc, "expecting float expressed type");
102     return nullptr;
103   }
104 
105   // Calibrated min and max values.
106   if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
107       parser.parseFloat(max) || parser.parseGreater()) {
108     parser.emitError(typeLoc, "calibrated values must be present");
109     return nullptr;
110   }
111   return type;
112 }
113 
114 /// Parses an AnyQuantizedType.
115 ///
116 ///   any ::= `any<` storage-spec (expressed-type-spec)?`>`
117 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
118 ///   storage-range ::= integer-literal `:` integer-literal
119 ///   storage-type ::= (`i` | `u`) integer-literal
120 ///   expressed-type-spec ::= `:` `f` integer-literal
parseAnyType(DialectAsmParser & parser)121 static Type parseAnyType(DialectAsmParser &parser) {
122   IntegerType storageType;
123   FloatType expressedType;
124   unsigned typeFlags = 0;
125   int64_t storageTypeMin;
126   int64_t storageTypeMax;
127 
128   // Type specification.
129   if (parser.parseLess())
130     return nullptr;
131 
132   // Storage type.
133   bool isSigned = false;
134   storageType = parseStorageType(parser, isSigned);
135   if (!storageType) {
136     return nullptr;
137   }
138   if (isSigned) {
139     typeFlags |= QuantizationFlags::Signed;
140   }
141 
142   // Storage type range.
143   if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
144                         storageTypeMax)) {
145     return nullptr;
146   }
147 
148   // Optional expressed type.
149   if (succeeded(parser.parseOptionalColon())) {
150     if (parser.parseType(expressedType)) {
151       return nullptr;
152     }
153   }
154 
155   if (parser.parseGreater()) {
156     return nullptr;
157   }
158 
159   return parser.getChecked<AnyQuantizedType>(
160       typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
161 }
162 
parseQuantParams(DialectAsmParser & parser,double & scale,int64_t & zeroPoint)163 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
164                                     int64_t &zeroPoint) {
165   // scale[:zeroPoint]?
166   // scale.
167   if (parser.parseFloat(scale))
168     return failure();
169 
170   // zero point.
171   zeroPoint = 0;
172   if (failed(parser.parseOptionalColon())) {
173     // Default zero point.
174     return success();
175   }
176 
177   return parser.parseInteger(zeroPoint);
178 }
179 
180 /// Parses a UniformQuantizedType.
181 ///
182 ///   uniform_type ::= uniform_per_layer
183 ///                  | uniform_per_axis
184 ///   uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
185 ///                          `,` scale-zero `>`
186 ///   uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
187 ///                        axis-spec `,` scale-zero-list `>`
188 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
189 ///   storage-range ::= integer-literal `:` integer-literal
190 ///   storage-type ::= (`i` | `u`) integer-literal
191 ///   expressed-type-spec ::= `:` `f` integer-literal
192 ///   axis-spec ::= `:` integer-literal
193 ///   scale-zero ::= float-literal `:` integer-literal
194 ///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
parseUniformType(DialectAsmParser & parser)195 static Type parseUniformType(DialectAsmParser &parser) {
196   IntegerType storageType;
197   FloatType expressedType;
198   unsigned typeFlags = 0;
199   int64_t storageTypeMin;
200   int64_t storageTypeMax;
201   bool isPerAxis = false;
202   int32_t quantizedDimension;
203   SmallVector<double, 1> scales;
204   SmallVector<int64_t, 1> zeroPoints;
205 
206   // Type specification.
207   if (parser.parseLess()) {
208     return nullptr;
209   }
210 
211   // Storage type.
212   bool isSigned = false;
213   storageType = parseStorageType(parser, isSigned);
214   if (!storageType) {
215     return nullptr;
216   }
217   if (isSigned) {
218     typeFlags |= QuantizationFlags::Signed;
219   }
220 
221   // Storage type range.
222   if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
223                         storageTypeMax)) {
224     return nullptr;
225   }
226 
227   // Expressed type.
228   if (parser.parseColon() || parser.parseType(expressedType)) {
229     return nullptr;
230   }
231 
232   // Optionally parse quantized dimension for per-axis quantization.
233   if (succeeded(parser.parseOptionalColon())) {
234     if (parser.parseInteger(quantizedDimension))
235       return nullptr;
236     isPerAxis = true;
237   }
238 
239   // Comma leading into range_spec.
240   if (parser.parseComma()) {
241     return nullptr;
242   }
243 
244   // Parameter specification.
245   // For per-axis, ranges are in a {} delimitted list.
246   if (isPerAxis) {
247     if (parser.parseLBrace()) {
248       return nullptr;
249     }
250   }
251 
252   // Parse scales/zeroPoints.
253   SMLoc scaleZPLoc = parser.getCurrentLocation();
254   do {
255     scales.resize(scales.size() + 1);
256     zeroPoints.resize(zeroPoints.size() + 1);
257     if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
258       return nullptr;
259     }
260   } while (isPerAxis && succeeded(parser.parseOptionalComma()));
261 
262   if (isPerAxis) {
263     if (parser.parseRBrace()) {
264       return nullptr;
265     }
266   }
267 
268   if (parser.parseGreater()) {
269     return nullptr;
270   }
271 
272   if (!isPerAxis && scales.size() > 1) {
273     return (parser.emitError(scaleZPLoc,
274                              "multiple scales/zeroPoints provided, but "
275                              "quantizedDimension wasn't specified"),
276             nullptr);
277   }
278 
279   if (isPerAxis) {
280     ArrayRef<double> scalesRef(scales.begin(), scales.end());
281     ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
282     return parser.getChecked<UniformQuantizedPerAxisType>(
283         typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
284         quantizedDimension, storageTypeMin, storageTypeMax);
285   }
286 
287   return parser.getChecked<UniformQuantizedType>(
288       typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
289       storageTypeMin, storageTypeMax);
290 }
291 
292 /// Parses an CalibratedQuantizedType.
293 ///
294 ///   calibrated ::= `calibrated<` expressed-spec `>`
295 ///   expressed-spec ::= expressed-type `<` calibrated-range `>`
296 ///   expressed-type ::= `f` integer-literal
297 ///   calibrated-range ::= float-literal `:` float-literal
parseCalibratedType(DialectAsmParser & parser)298 static Type parseCalibratedType(DialectAsmParser &parser) {
299   FloatType expressedType;
300   double min;
301   double max;
302 
303   // Type specification.
304   if (parser.parseLess())
305     return nullptr;
306 
307   // Expressed type.
308   expressedType = parseExpressedTypeAndRange(parser, min, max);
309   if (!expressedType) {
310     return nullptr;
311   }
312 
313   if (parser.parseGreater()) {
314     return nullptr;
315   }
316 
317   return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
318 }
319 
320 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const321 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
322   // All types start with an identifier that we switch on.
323   StringRef typeNameSpelling;
324   if (failed(parser.parseKeyword(&typeNameSpelling)))
325     return nullptr;
326 
327   if (typeNameSpelling == "uniform")
328     return parseUniformType(parser);
329   if (typeNameSpelling == "any")
330     return parseAnyType(parser);
331   if (typeNameSpelling == "calibrated")
332     return parseCalibratedType(parser);
333 
334   parser.emitError(parser.getNameLoc(),
335                    "unknown quantized type " + typeNameSpelling);
336   return nullptr;
337 }
338 
printStorageType(QuantizedType type,DialectAsmPrinter & out)339 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
340   // storage type
341   unsigned storageWidth = type.getStorageTypeIntegralWidth();
342   bool isSigned = type.isSigned();
343   if (isSigned) {
344     out << "i" << storageWidth;
345   } else {
346     out << "u" << storageWidth;
347   }
348 
349   // storageTypeMin and storageTypeMax if not default.
350   int64_t defaultIntegerMin =
351       QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
352   int64_t defaultIntegerMax =
353       QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
354   if (defaultIntegerMin != type.getStorageTypeMin() ||
355       defaultIntegerMax != type.getStorageTypeMax()) {
356     out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
357         << ">";
358   }
359 }
360 
printQuantParams(double scale,int64_t zeroPoint,DialectAsmPrinter & out)361 static void printQuantParams(double scale, int64_t zeroPoint,
362                              DialectAsmPrinter &out) {
363   out << scale;
364   if (zeroPoint != 0) {
365     out << ":" << zeroPoint;
366   }
367 }
368 
369 /// Helper that prints a AnyQuantizedType.
printAnyQuantizedType(AnyQuantizedType type,DialectAsmPrinter & out)370 static void printAnyQuantizedType(AnyQuantizedType type,
371                                   DialectAsmPrinter &out) {
372   out << "any<";
373   printStorageType(type, out);
374   if (Type expressedType = type.getExpressedType()) {
375     out << ":" << expressedType;
376   }
377   out << ">";
378 }
379 
380 /// Helper that prints a UniformQuantizedType.
printUniformQuantizedType(UniformQuantizedType type,DialectAsmPrinter & out)381 static void printUniformQuantizedType(UniformQuantizedType type,
382                                       DialectAsmPrinter &out) {
383   out << "uniform<";
384   printStorageType(type, out);
385   out << ":" << type.getExpressedType() << ", ";
386 
387   // scheme specific parameters
388   printQuantParams(type.getScale(), type.getZeroPoint(), out);
389   out << ">";
390 }
391 
392 /// Helper that prints a UniformQuantizedPerAxisType.
printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,DialectAsmPrinter & out)393 static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
394                                              DialectAsmPrinter &out) {
395   out << "uniform<";
396   printStorageType(type, out);
397   out << ":" << type.getExpressedType() << ":";
398   out << type.getQuantizedDimension();
399   out << ", ";
400 
401   // scheme specific parameters
402   ArrayRef<double> scales = type.getScales();
403   ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
404   out << "{";
405   llvm::interleave(
406       llvm::seq<size_t>(0, scales.size()), out,
407       [&](size_t index) {
408         printQuantParams(scales[index], zeroPoints[index], out);
409       },
410       ",");
411   out << "}>";
412 }
413 
414 /// Helper that prints a CalibratedQuantizedType.
printCalibratedQuantizedType(CalibratedQuantizedType type,DialectAsmPrinter & out)415 static void printCalibratedQuantizedType(CalibratedQuantizedType type,
416                                          DialectAsmPrinter &out) {
417   out << "calibrated<" << type.getExpressedType();
418   out << "<" << type.getMin() << ":" << type.getMax() << ">";
419   out << ">";
420 }
421 
422 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const423 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
424   if (auto anyType = type.dyn_cast<AnyQuantizedType>())
425     printAnyQuantizedType(anyType, os);
426   else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
427     printUniformQuantizedType(uniformType, os);
428   else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
429     printUniformQuantizedPerAxisType(perAxisType, os);
430   else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
431     printCalibratedQuantizedType(calibratedType, os);
432   else
433     llvm_unreachable("Unhandled quantized type");
434 }
435