1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/FakeQuantSupport.h" 10 #include "mlir/Dialect/Quant/Passes.h" 11 #include "mlir/Dialect/Quant/QuantOps.h" 12 #include "mlir/Dialect/Quant/UniformSupport.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::quant; 20 21 namespace { 22 23 class ConvertSimulatedQuantPass 24 : public FunctionPass<ConvertSimulatedQuantPass> { 25 public: 26 void runOnFunction() override; 27 }; 28 29 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 30 template <typename ConcreteRewriteClass, typename FakeQuantOp> 31 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 32 public: 33 using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 34 35 FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 36 : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 37 38 LogicalResult matchAndRewrite(FakeQuantOp op, 39 PatternRewriter &rewriter) const override { 40 // TODO: If this pattern comes up more frequently, consider adding core 41 // support for failable rewrites. 42 if (failableRewrite(op, rewriter)) { 43 *hadFailure = true; 44 return failure(); 45 } 46 47 return success(); 48 } 49 50 private: 51 bool *hadFailure; 52 53 bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 54 auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 55 if (!converter) { 56 return (op.emitError("unsupported quantized type conversion"), true); 57 } 58 59 QuantizedType elementType = 60 static_cast<const ConcreteRewriteClass *>(this) 61 ->convertFakeQuantAttrsToType(op, converter.expressedType); 62 63 if (!elementType) { 64 // Note that the fakeQuantAttrsToType will have emitted the error. 65 return true; 66 } 67 68 Type quantizedType = converter.convert(elementType); 69 assert(quantizedType && 70 "Converter accepted a type that it did not convert"); 71 72 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 73 // this is a forced/hard-coded constraint. 74 auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 75 op.inputs()); 76 rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 77 qbarrier.getResult()); 78 79 return false; 80 } 81 }; 82 83 class ConstFakeQuantRewrite 84 : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 85 public: 86 using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 87 88 ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 89 : BaseRewrite(ctx, hadFailure) {} 90 91 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 92 Type expressedType) const { 93 return fakeQuantAttrsToType( 94 fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 95 fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 96 fqOp.narrow_range(), expressedType, fqOp.is_signed()); 97 } 98 }; 99 100 class ConstFakeQuantPerAxisRewrite 101 : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 102 ConstFakeQuantPerAxis> { 103 public: 104 using BaseRewrite = 105 FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 106 107 ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 108 : BaseRewrite(ctx, hadFailure) {} 109 110 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 111 Type expressedType) const { 112 SmallVector<double, 4> min, max; 113 min.reserve(fqOp.min().size()); 114 max.reserve(fqOp.max().size()); 115 for (auto m : fqOp.min()) 116 min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 117 for (auto m : fqOp.max()) 118 max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 119 120 return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 121 fqOp.axis().getSExtValue(), min, max, 122 fqOp.narrow_range(), expressedType, 123 fqOp.is_signed()); 124 } 125 }; 126 127 } // namespace 128 129 void ConvertSimulatedQuantPass::runOnFunction() { 130 bool hadFailure = false; 131 OwningRewritePatternList patterns; 132 auto func = getFunction(); 133 auto ctx = func.getContext(); 134 patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 135 ctx, &hadFailure); 136 applyPatternsGreedily(func, patterns); 137 if (hadFailure) 138 signalPassFailure(); 139 } 140 141 std::unique_ptr<OpPassBase<FuncOp>> 142 mlir::quant::createConvertSimulatedQuantPass() { 143 return std::make_unique<ConvertSimulatedQuantPass>(); 144 } 145 146 static PassRegistration<ConvertSimulatedQuantPass> 147 pass("quant-convert-simulated-quantization", 148 "Converts training-time simulated quantization ops to corresponding " 149 "quantize/dequantize casts."); 150