//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "TypeDetail.h"

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>

using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;

#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"

void QuantizationDialect::initialize() {
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
      >();
}

OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
  // value of x if the casts invert each other.
  auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
  if (!srcScastOp || srcScastOp.getArg().getType() != getType())
    return OpFoldResult();
  return srcScastOp.getArg();
}
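
// Illustrative sketch of the pattern folded above (hypothetical IR; the
// concrete shapes and the quantized type are assumptions, not taken from a
// test case):
//
//   %1 = "quant.scast"(%0) : (tensor<4x!quant.uniform<i8:f32, 1.0>>)
//                          -> tensor<4xi8>
//   %2 = "quant.scast"(%1) : (tensor<4xi8>)
//                          -> tensor<4x!quant.uniform<i8:f32, 1.0>>
//
// Because %2 has the same type as %0, the second scast folds to %0.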

/// The quantization specification should match the expressed type.
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
  if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
    Type spec = typeAttr.getValue();
    if (spec.isa<TensorType, VectorType>())
      return false;

    // The spec should be either a quantized type which is compatible with the
    // expressed type, or a primitive type which is the same as the (element
    // type of the) expressed type.
    if (auto quantizedType = spec.dyn_cast<QuantizedType>())
      return quantizedType.isCompatibleExpressedType(expressed);

    if (auto tensorType = expressed.dyn_cast<TensorType>())
      return spec == tensorType.getElementType();

    if (auto vectorType = expressed.dyn_cast<VectorType>())
      return spec == vectorType.getElementType();
  }
  return false;
}
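
// For illustration (a hedged example, not drawn from an existing test): with
// an expressed type of tensor<4xf32>, a spec of !quant.uniform<i8:f32, 1.0>
// is valid because that quantized type expresses f32, and a spec of plain f32
// is valid because it equals the element type; a spec of i8, or a spec that
// is itself a tensor or vector type, is rejected.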

LogicalResult QuantizeRegionOp::verify() {
  // There are specifications for both inputs and outputs.
  if (getNumOperands() != getInputSpecs().size() ||
      getNumResults() != getOutputSpecs().size())
    return emitOpError(
        "has unmatched operands/results number and spec attributes number");

  // Verify that quantization specifications are valid.
  for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) {
    Type inputType = std::get<0>(input);
    Attribute inputSpec = std::get<1>(input);
    if (!isValidQuantizationSpec(inputSpec, inputType)) {
      return emitOpError() << "has incompatible specification " << inputSpec
                           << " and input type " << inputType;
    }
  }

  for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) {
    Type outputType = std::get<0>(result);
    Attribute outputSpec = std::get<1>(result);
    if (!isValidQuantizationSpec(outputSpec, outputType)) {
      return emitOpError() << "has incompatible specification " << outputSpec
                           << " and output type " << outputType;
    }
  }
  return success();
}
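
// As a concrete reading of the checks above (counts are illustrative only): a
// region op with two operands of type tensor<4xf32> and one result of type
// tensor<4xf32> must carry exactly two input spec entries and one output spec
// entry, each either a compatible quantized type such as
// !quant.uniform<i8:f32, 1.0> or the f32 element type itself.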

LogicalResult StatisticsOp::verify() {
  auto tensorArg = getArg().getType().dyn_cast<TensorType>();
  if (!tensorArg)
    return emitOpError("arg needs to be tensor type.");

  // Verify layerStats attribute.
  {
    auto layerStatsType = getLayerStats().getType();
    if (!layerStatsType.getElementType().isa<FloatType>()) {
      return emitOpError("layerStats must have a floating point element type");
    }
    if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
      return emitOpError("layerStats must have shape [2]");
    }
  }
  // Verify axisStats (optional) attribute.
  if (getAxisStats()) {
    if (!getAxis())
      return emitOpError("axis must be specified for axisStats");

    auto shape = tensorArg.getShape();
    auto argSliceSize =
        std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1,
                        std::multiplies<int64_t>());

    auto axisStatsType = getAxisStats()->getType();
    if (!axisStatsType.getElementType().isa<FloatType>()) {
      return emitOpError("axisStats must have a floating point element type");
    }
    if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 ||
        axisStatsType.getDimSize(0) != argSliceSize) {
      return emitOpError("axisStats must have shape [N,2] "
                         "where N = the slice size defined by the axis dim");
    }
  }
  return success();
}
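
// Sketch of a stats op that satisfies this verifier (values and generic
// attribute syntax are assumptions for illustration):
//
//   %1 = "quant.stats"(%0) {
//     layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
//     axisStats = dense<[[-1.0, 1.0], [-0.5, 0.5], [-2.0, 2.0]]>
//         : tensor<3x2xf32>,
//     axis = 2 : i64
//   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
//
// With axis = 2 the slice size is the product of dims [2, rank) = 3, so
// axisStats must have shape [3, 2]; layerStats always has shape [2].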

#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"