1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "TypeDetail.h"
11
12 #include "mlir/Dialect/Quant/QuantTypes.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <numeric>
21
22 using namespace mlir;
23 using namespace mlir::quant;
24 using namespace mlir::quant::detail;
25
26 #include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
27
/// Registers the quant dialect's contents: the four quantized type classes
/// listed in addTypes<> and every op in the TableGen-generated op list.
void QuantizationDialect::initialize() {
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  // With GET_OP_LIST defined, the generated .inc file is expanded directly
  // inside the template argument list, so it must produce the op class names
  // (presumably comma-separated) that addOperations<> registers.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
      >();
}
36
fold(ArrayRef<Attribute> operands)37 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
38 // Matches x -> [scast -> scast] -> y, replacing the second scast with the
39 // value of x if the casts invert each other.
40 auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
41 if (!srcScastOp || srcScastOp.getArg().getType() != getType())
42 return OpFoldResult();
43 return srcScastOp.getArg();
44 }
45
46 /// The quantization specification should match the expressed type.
isValidQuantizationSpec(Attribute quantSpec,Type expressed)47 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
48 if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
49 Type spec = typeAttr.getValue();
50 if (spec.isa<TensorType, VectorType>())
51 return false;
52
53 // The spec should be either a quantized type which is compatible to the
54 // expressed type, or a primitive type which is as same as the
55 // (element type of) the expressed type.
56 if (auto quantizedType = spec.dyn_cast<QuantizedType>())
57 return quantizedType.isCompatibleExpressedType(expressed);
58
59 if (auto tensorType = expressed.dyn_cast<TensorType>())
60 return spec == tensorType.getElementType();
61
62 if (auto vectorType = expressed.dyn_cast<VectorType>())
63 return spec == vectorType.getElementType();
64 }
65 return false;
66 }
67
verify()68 LogicalResult QuantizeRegionOp::verify() {
69 // There are specifications for both inputs and outputs.
70 if (getNumOperands() != getInputSpecs().size() ||
71 getNumResults() != getOutputSpecs().size())
72 return emitOpError(
73 "has unmatched operands/results number and spec attributes number");
74
75 // Verify that quantization specifications are valid.
76 for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) {
77 Type inputType = std::get<0>(input);
78 Attribute inputSpec = std::get<1>(input);
79 if (!isValidQuantizationSpec(inputSpec, inputType)) {
80 return emitOpError() << "has incompatible specification " << inputSpec
81 << " and input type " << inputType;
82 }
83 }
84
85 for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) {
86 Type outputType = std::get<0>(result);
87 Attribute outputSpec = std::get<1>(result);
88 if (!isValidQuantizationSpec(outputSpec, outputType)) {
89 return emitOpError() << "has incompatible specification " << outputSpec
90 << " and output type " << outputType;
91 }
92 }
93 return success();
94 }
95
verify()96 LogicalResult StatisticsOp::verify() {
97 auto tensorArg = getArg().getType().dyn_cast<TensorType>();
98 if (!tensorArg)
99 return emitOpError("arg needs to be tensor type.");
100
101 // Verify layerStats attribute.
102 {
103 auto layerStatsType = getLayerStats().getType();
104 if (!layerStatsType.getElementType().isa<FloatType>()) {
105 return emitOpError("layerStats must have a floating point element type");
106 }
107 if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
108 return emitOpError("layerStats must have shape [2]");
109 }
110 }
111 // Verify axisStats (optional) attribute.
112 if (getAxisStats()) {
113 if (!getAxis())
114 return emitOpError("axis must be specified for axisStats");
115
116 auto shape = tensorArg.getShape();
117 auto argSliceSize =
118 std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1,
119 std::multiplies<int64_t>());
120
121 auto axisStatsType = getAxisStats()->getType();
122 if (!axisStatsType.getElementType().isa<FloatType>()) {
123 return emitOpError("axisStats must have a floating point element type");
124 }
125 if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 ||
126 axisStatsType.getDimSize(0) != argSliceSize) {
127 return emitOpError("axisStats must have shape [N,2] "
128 "where N = the slice size defined by the axis dim");
129 }
130 }
131 return success();
132 }
133
134 #define GET_OP_CLASSES
135 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
136