1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantOps.h" 10 #include "TypeDetail.h" 11 12 #include "mlir/Dialect/Quant/QuantTypes.h" 13 #include "mlir/IR/MLIRContext.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/MathExtras.h" 20 #include <numeric> 21 22 using namespace mlir; 23 using namespace mlir::quant; 24 using namespace mlir::quant::detail; 25 26 QuantizationDialect::QuantizationDialect(MLIRContext *context) 27 : Dialect(/*name=*/"quant", context) { 28 addTypes<AnyQuantizedType, UniformQuantizedType, 29 UniformQuantizedPerAxisType>(); 30 addOperations< 31 #define GET_OP_LIST 32 #include "mlir/Dialect/Quant/QuantOps.cpp.inc" 33 >(); 34 } 35 36 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) { 37 // Matches x -> [scast -> scast] -> y, replacing the second scast with the 38 // value of x if the casts invert each other. 39 auto srcScastOp = arg().getDefiningOp<StorageCastOp>(); 40 if (!srcScastOp || srcScastOp.arg().getType() != getType()) 41 return OpFoldResult(); 42 return srcScastOp.arg(); 43 } 44 45 /// The quantization specification should match the expressed type. 46 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { 47 if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) { 48 Type spec = typeAttr.getValue(); 49 if (spec.isa<TensorType>() || spec.isa<VectorType>()) 50 return false; 51 52 // The spec should be either a quantized type which is compatible to the 53 // expressed type, or a primitive type which is as same as the 54 // (element type of) the expressed type. 55 if (auto quantizedType = spec.dyn_cast<QuantizedType>()) 56 return quantizedType.isCompatibleExpressedType(expressed); 57 58 if (auto tensorType = expressed.dyn_cast<TensorType>()) 59 return spec == tensorType.getElementType(); 60 61 if (auto vectorType = expressed.dyn_cast<VectorType>()) 62 return spec == vectorType.getElementType(); 63 } 64 return false; 65 } 66 67 static LogicalResult verifyRegionOp(QuantizeRegionOp op) { 68 // There are specifications for both inputs and outputs. 69 if (op.getNumOperands() != op.input_specs().size() || 70 op.getNumResults() != op.output_specs().size()) 71 return op.emitOpError( 72 "has unmatched operands/results number and spec attributes number"); 73 74 // Verify that quantization specifications are valid. 75 for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { 76 Type inputType = std::get<0>(input); 77 Attribute inputSpec = std::get<1>(input); 78 if (!isValidQuantizationSpec(inputSpec, inputType)) { 79 return op.emitOpError() << "has incompatible specification " << inputSpec 80 << " and input type " << inputType; 81 } 82 } 83 84 for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { 85 Type outputType = std::get<0>(result); 86 Attribute outputSpec = std::get<1>(result); 87 if (!isValidQuantizationSpec(outputSpec, outputType)) { 88 return op.emitOpError() << "has incompatible specification " << outputSpec 89 << " and output type " << outputType; 90 } 91 } 92 return success(); 93 } 94 95 #define GET_OP_CLASSES 96 #include "mlir/Dialect/Quant/QuantOps.cpp.inc" 97