1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "TypeDetail.h"
11 
12 #include "mlir/Dialect/Quant/QuantTypes.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <numeric>
21 
22 using namespace mlir;
23 using namespace mlir::quant;
24 using namespace mlir::quant::detail;
25 
/// Constructs the "quant" dialect in `context` and registers its types
/// (AnyQuantizedType, UniformQuantizedType, UniformQuantizedPerAxisType) and
/// its operations.
QuantizationDialect::QuantizationDialect(MLIRContext *context)
    : Dialect(/*name=*/"quant", context) {
  addTypes<AnyQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  // The generated QuantOps.cpp.inc is included with GET_OP_LIST defined, so
  // that (presumably) it expands to the list of op classes used as the
  // template arguments of addOperations below.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
      >();
}
35 
36 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
37   // Matches x -> [scast -> scast] -> y, replacing the second scast with the
38   // value of x if the casts invert each other.
39   auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
40   if (!srcScastOp || srcScastOp.arg().getType() != getType())
41     return OpFoldResult();
42   return srcScastOp.arg();
43 }
44 
45 /// The quantization specification should match the expressed type.
46 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
47   if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
48     Type spec = typeAttr.getValue();
49     if (spec.isa<TensorType>() || spec.isa<VectorType>())
50       return false;
51 
52     // The spec should be either a quantized type which is compatible to the
53     // expressed type, or a primitive type which is as same as the
54     // (element type of) the expressed type.
55     if (auto quantizedType = spec.dyn_cast<QuantizedType>())
56       return quantizedType.isCompatibleExpressedType(expressed);
57 
58     if (auto tensorType = expressed.dyn_cast<TensorType>())
59       return spec == tensorType.getElementType();
60 
61     if (auto vectorType = expressed.dyn_cast<VectorType>())
62       return spec == vectorType.getElementType();
63   }
64   return false;
65 }
66 
67 static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
68   // There are specifications for both inputs and outputs.
69   if (op.getNumOperands() != op.input_specs().size() ||
70       op.getNumResults() != op.output_specs().size())
71     return op.emitOpError(
72         "has unmatched operands/results number and spec attributes number");
73 
74   // Verify that quantization specifications are valid.
75   for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) {
76     Type inputType = std::get<0>(input);
77     Attribute inputSpec = std::get<1>(input);
78     if (!isValidQuantizationSpec(inputSpec, inputType)) {
79       return op.emitOpError() << "has incompatible specification " << inputSpec
80                               << " and input type " << inputType;
81     }
82   }
83 
84   for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) {
85     Type outputType = std::get<0>(result);
86     Attribute outputSpec = std::get<1>(result);
87     if (!isValidQuantizationSpec(outputSpec, outputType)) {
88       return op.emitOpError() << "has incompatible specification " << outputSpec
89                               << " and output type " << outputType;
90     }
91   }
92   return success();
93 }
94 
95 #define GET_OP_CLASSES
96 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
97