1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/QuantOps.h" 10 #include "TypeDetail.h" 11 12 #include "mlir/Dialect/Quant/QuantTypes.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/MathExtras.h" 20 #include <numeric> 21 22 using namespace mlir; 23 using namespace mlir::quant; 24 using namespace mlir::quant::detail; 25 26 #include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc" 27 28 void QuantizationDialect::initialize() { 29 addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType, 30 UniformQuantizedPerAxisType>(); 31 addOperations< 32 #define GET_OP_LIST 33 #include "mlir/Dialect/Quant/QuantOps.cpp.inc" 34 >(); 35 } 36 37 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) { 38 // Matches x -> [scast -> scast] -> y, replacing the second scast with the 39 // value of x if the casts invert each other. 40 auto srcScastOp = getArg().getDefiningOp<StorageCastOp>(); 41 if (!srcScastOp || srcScastOp.getArg().getType() != getType()) 42 return OpFoldResult(); 43 return srcScastOp.getArg(); 44 } 45 46 /// The quantization specification should match the expressed type. 47 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { 48 if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) { 49 Type spec = typeAttr.getValue(); 50 if (spec.isa<TensorType, VectorType>()) 51 return false; 52 53 // The spec should be either a quantized type which is compatible to the 54 // expressed type, or a primitive type which is as same as the 55 // (element type of) the expressed type. 56 if (auto quantizedType = spec.dyn_cast<QuantizedType>()) 57 return quantizedType.isCompatibleExpressedType(expressed); 58 59 if (auto tensorType = expressed.dyn_cast<TensorType>()) 60 return spec == tensorType.getElementType(); 61 62 if (auto vectorType = expressed.dyn_cast<VectorType>()) 63 return spec == vectorType.getElementType(); 64 } 65 return false; 66 } 67 68 LogicalResult QuantizeRegionOp::verify() { 69 // There are specifications for both inputs and outputs. 70 if (getNumOperands() != getInputSpecs().size() || 71 getNumResults() != getOutputSpecs().size()) 72 return emitOpError( 73 "has unmatched operands/results number and spec attributes number"); 74 75 // Verify that quantization specifications are valid. 76 for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) { 77 Type inputType = std::get<0>(input); 78 Attribute inputSpec = std::get<1>(input); 79 if (!isValidQuantizationSpec(inputSpec, inputType)) { 80 return emitOpError() << "has incompatible specification " << inputSpec 81 << " and input type " << inputType; 82 } 83 } 84 85 for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) { 86 Type outputType = std::get<0>(result); 87 Attribute outputSpec = std::get<1>(result); 88 if (!isValidQuantizationSpec(outputSpec, outputType)) { 89 return emitOpError() << "has incompatible specification " << outputSpec 90 << " and output type " << outputType; 91 } 92 } 93 return success(); 94 } 95 96 LogicalResult StatisticsOp::verify() { 97 auto tensorArg = getArg().getType().dyn_cast<TensorType>(); 98 if (!tensorArg) 99 return emitOpError("arg needs to be tensor type."); 100 101 // Verify layerStats attribute. 102 { 103 auto layerStatsType = getLayerStats().getType(); 104 if (!layerStatsType.getElementType().isa<FloatType>()) { 105 return emitOpError("layerStats must have a floating point element type"); 106 } 107 if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { 108 return emitOpError("layerStats must have shape [2]"); 109 } 110 } 111 // Verify axisStats (optional) attribute. 112 if (getAxisStats()) { 113 if (!getAxis()) 114 return emitOpError("axis must be specified for axisStats"); 115 116 auto shape = tensorArg.getShape(); 117 auto argSliceSize = 118 std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1, 119 std::multiplies<int64_t>()); 120 121 auto axisStatsType = getAxisStats()->getType(); 122 if (!axisStatsType.getElementType().isa<FloatType>()) { 123 return emitOpError("axisStats must have a floating point element type"); 124 } 125 if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || 126 axisStatsType.getDimSize(0) != argSliceSize) { 127 return emitOpError("axisStats must have shape [N,2] " 128 "where N = the slice size defined by the axis dim"); 129 } 130 } 131 return success(); 132 } 133 134 #define GET_OP_CLASSES 135 #include "mlir/Dialect/Quant/QuantOps.cpp.inc" 136