1*363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2*363dd3f3SRob Suderman // 3*363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5*363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*363dd3f3SRob Suderman // 7*363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8*363dd3f3SRob Suderman 9*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h" 10*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 11*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 12*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 13*363dd3f3SRob Suderman #include "mlir/IR/Attributes.h" 14*363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h" 15*363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h" 16*363dd3f3SRob Suderman #include "mlir/Pass/Pass.h" 17*363dd3f3SRob Suderman 18*363dd3f3SRob Suderman using namespace mlir; 19*363dd3f3SRob Suderman using namespace mlir::quant; 20*363dd3f3SRob Suderman 21*363dd3f3SRob Suderman namespace { 22*363dd3f3SRob Suderman 23*363dd3f3SRob Suderman class ConvertSimulatedQuantPass 24*363dd3f3SRob Suderman : public FunctionPass<ConvertSimulatedQuantPass> { 25*363dd3f3SRob Suderman public: 26*363dd3f3SRob Suderman void runOnFunction() override; 27*363dd3f3SRob Suderman }; 28*363dd3f3SRob Suderman 29*363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 30*363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp> 31*363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 32*363dd3f3SRob Suderman public: 33*363dd3f3SRob Suderman using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 34*363dd3f3SRob Suderman 35*363dd3f3SRob Suderman FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 36*363dd3f3SRob Suderman : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 37*363dd3f3SRob Suderman 38*363dd3f3SRob Suderman PatternMatchResult matchAndRewrite(FakeQuantOp op, 39*363dd3f3SRob Suderman PatternRewriter &rewriter) const override { 40*363dd3f3SRob Suderman // TODO: If this pattern comes up more frequently, consider adding core 41*363dd3f3SRob Suderman // support for failable rewrites. 42*363dd3f3SRob Suderman if (failableRewrite(op, rewriter)) { 43*363dd3f3SRob Suderman *hadFailure = true; 44*363dd3f3SRob Suderman return Pattern::matchFailure(); 45*363dd3f3SRob Suderman } 46*363dd3f3SRob Suderman 47*363dd3f3SRob Suderman return Pattern::matchSuccess(); 48*363dd3f3SRob Suderman } 49*363dd3f3SRob Suderman 50*363dd3f3SRob Suderman private: 51*363dd3f3SRob Suderman bool *hadFailure; 52*363dd3f3SRob Suderman 53*363dd3f3SRob Suderman bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 54*363dd3f3SRob Suderman auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 55*363dd3f3SRob Suderman if (!converter) { 56*363dd3f3SRob Suderman return (op.emitError("unsupported quantized type conversion"), true); 57*363dd3f3SRob Suderman } 58*363dd3f3SRob Suderman 59*363dd3f3SRob Suderman QuantizedType elementType = 60*363dd3f3SRob Suderman static_cast<const ConcreteRewriteClass *>(this) 61*363dd3f3SRob Suderman ->convertFakeQuantAttrsToType(op, converter.expressedType); 62*363dd3f3SRob Suderman 63*363dd3f3SRob Suderman if (!elementType) { 64*363dd3f3SRob Suderman // Note that the fakeQuantAttrsToType will have emitted the error. 65*363dd3f3SRob Suderman return true; 66*363dd3f3SRob Suderman } 67*363dd3f3SRob Suderman 68*363dd3f3SRob Suderman Type quantizedType = converter.convert(elementType); 69*363dd3f3SRob Suderman assert(quantizedType && 70*363dd3f3SRob Suderman "Converter accepted a type that it did not convert"); 71*363dd3f3SRob Suderman 72*363dd3f3SRob Suderman // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 73*363dd3f3SRob Suderman // this is a forced/hard-coded constraint. 74*363dd3f3SRob Suderman auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 75*363dd3f3SRob Suderman op.inputs()); 76*363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 77*363dd3f3SRob Suderman qbarrier.getResult()); 78*363dd3f3SRob Suderman 79*363dd3f3SRob Suderman return false; 80*363dd3f3SRob Suderman } 81*363dd3f3SRob Suderman }; 82*363dd3f3SRob Suderman 83*363dd3f3SRob Suderman class ConstFakeQuantRewrite 84*363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 85*363dd3f3SRob Suderman public: 86*363dd3f3SRob Suderman using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 87*363dd3f3SRob Suderman 88*363dd3f3SRob Suderman ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 89*363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 90*363dd3f3SRob Suderman 91*363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 92*363dd3f3SRob Suderman Type expressedType) const { 93*363dd3f3SRob Suderman return fakeQuantAttrsToType( 94*363dd3f3SRob Suderman fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 95*363dd3f3SRob Suderman fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 96*363dd3f3SRob Suderman fqOp.narrow_range(), expressedType, fqOp.is_signed()); 97*363dd3f3SRob Suderman } 98*363dd3f3SRob Suderman }; 99*363dd3f3SRob Suderman 100*363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite 101*363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 102*363dd3f3SRob Suderman ConstFakeQuantPerAxis> { 103*363dd3f3SRob Suderman public: 104*363dd3f3SRob Suderman using BaseRewrite = 105*363dd3f3SRob Suderman FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 106*363dd3f3SRob Suderman 107*363dd3f3SRob Suderman ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 108*363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 109*363dd3f3SRob Suderman 110*363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 111*363dd3f3SRob Suderman Type expressedType) const { 112*363dd3f3SRob Suderman SmallVector<double, 4> min, max; 113*363dd3f3SRob Suderman min.reserve(fqOp.min().size()); 114*363dd3f3SRob Suderman max.reserve(fqOp.max().size()); 115*363dd3f3SRob Suderman for (auto m : fqOp.min()) 116*363dd3f3SRob Suderman min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 117*363dd3f3SRob Suderman for (auto m : fqOp.max()) 118*363dd3f3SRob Suderman max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 119*363dd3f3SRob Suderman 120*363dd3f3SRob Suderman return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 121*363dd3f3SRob Suderman fqOp.axis().getSExtValue(), min, max, 122*363dd3f3SRob Suderman fqOp.narrow_range(), expressedType, 123*363dd3f3SRob Suderman fqOp.is_signed()); 124*363dd3f3SRob Suderman } 125*363dd3f3SRob Suderman }; 126*363dd3f3SRob Suderman 127*363dd3f3SRob Suderman } // namespace 128*363dd3f3SRob Suderman 129*363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() { 130*363dd3f3SRob Suderman bool hadFailure = false; 131*363dd3f3SRob Suderman OwningRewritePatternList patterns; 132*363dd3f3SRob Suderman auto func = getFunction(); 133*363dd3f3SRob Suderman auto ctx = func.getContext(); 134*363dd3f3SRob Suderman patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 135*363dd3f3SRob Suderman ctx, &hadFailure); 136*363dd3f3SRob Suderman applyPatternsGreedily(func, patterns); 137*363dd3f3SRob Suderman if (hadFailure) 138*363dd3f3SRob Suderman signalPassFailure(); 139*363dd3f3SRob Suderman } 140*363dd3f3SRob Suderman 141*363dd3f3SRob Suderman std::unique_ptr<OpPassBase<FuncOp>> 142*363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() { 143*363dd3f3SRob Suderman return std::make_unique<ConvertSimulatedQuantPass>(); 144*363dd3f3SRob Suderman } 145*363dd3f3SRob Suderman 146*363dd3f3SRob Suderman static PassRegistration<ConvertSimulatedQuantPass> 147*363dd3f3SRob Suderman pass("quant-convert-simulated-quantization", 148*363dd3f3SRob Suderman "Converts training-time simulated quantization ops to corresponding " 149*363dd3f3SRob Suderman "quantize/dequantize casts."); 150