1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h" 10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 13363dd3f3SRob Suderman #include "mlir/IR/Attributes.h" 14363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h" 15363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h" 16363dd3f3SRob Suderman #include "mlir/Pass/Pass.h" 17363dd3f3SRob Suderman 18363dd3f3SRob Suderman using namespace mlir; 19363dd3f3SRob Suderman using namespace mlir::quant; 20363dd3f3SRob Suderman 21363dd3f3SRob Suderman namespace { 229a277af2SRiver Riddle struct ConvertSimulatedQuantPass 23*80aca1eaSRiver Riddle : public PassWrapper<ConvertSimulatedQuantPass, FunctionPass> { 249a277af2SRiver Riddle /// Include the generated pass utilities. 259a277af2SRiver Riddle #define GEN_PASS_QuantConvertSimulatedQuant 269a277af2SRiver Riddle #include "mlir/Dialect/Quant/Passes.h.inc" 279a277af2SRiver Riddle 28363dd3f3SRob Suderman void runOnFunction() override; 29363dd3f3SRob Suderman }; 30363dd3f3SRob Suderman 31363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 32363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp> 33363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 34363dd3f3SRob Suderman public: 35363dd3f3SRob Suderman using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 36363dd3f3SRob Suderman 37363dd3f3SRob Suderman FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 38363dd3f3SRob Suderman : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 39363dd3f3SRob Suderman 403145427dSRiver Riddle LogicalResult matchAndRewrite(FakeQuantOp op, 41363dd3f3SRob Suderman PatternRewriter &rewriter) const override { 42363dd3f3SRob Suderman // TODO: If this pattern comes up more frequently, consider adding core 43363dd3f3SRob Suderman // support for failable rewrites. 44363dd3f3SRob Suderman if (failableRewrite(op, rewriter)) { 45363dd3f3SRob Suderman *hadFailure = true; 463145427dSRiver Riddle return failure(); 47363dd3f3SRob Suderman } 48363dd3f3SRob Suderman 493145427dSRiver Riddle return success(); 50363dd3f3SRob Suderman } 51363dd3f3SRob Suderman 52363dd3f3SRob Suderman private: 53363dd3f3SRob Suderman bool *hadFailure; 54363dd3f3SRob Suderman 55363dd3f3SRob Suderman bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 56363dd3f3SRob Suderman auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 57363dd3f3SRob Suderman if (!converter) { 58363dd3f3SRob Suderman return (op.emitError("unsupported quantized type conversion"), true); 59363dd3f3SRob Suderman } 60363dd3f3SRob Suderman 61363dd3f3SRob Suderman QuantizedType elementType = 62363dd3f3SRob Suderman static_cast<const ConcreteRewriteClass *>(this) 63363dd3f3SRob Suderman ->convertFakeQuantAttrsToType(op, converter.expressedType); 64363dd3f3SRob Suderman 65363dd3f3SRob Suderman if (!elementType) { 66363dd3f3SRob Suderman // Note that the fakeQuantAttrsToType will have emitted the error. 67363dd3f3SRob Suderman return true; 68363dd3f3SRob Suderman } 69363dd3f3SRob Suderman 70363dd3f3SRob Suderman Type quantizedType = converter.convert(elementType); 71363dd3f3SRob Suderman assert(quantizedType && 72363dd3f3SRob Suderman "Converter accepted a type that it did not convert"); 73363dd3f3SRob Suderman 74363dd3f3SRob Suderman // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 75363dd3f3SRob Suderman // this is a forced/hard-coded constraint. 76363dd3f3SRob Suderman auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 77363dd3f3SRob Suderman op.inputs()); 78363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 79363dd3f3SRob Suderman qbarrier.getResult()); 80363dd3f3SRob Suderman 81363dd3f3SRob Suderman return false; 82363dd3f3SRob Suderman } 83363dd3f3SRob Suderman }; 84363dd3f3SRob Suderman 85363dd3f3SRob Suderman class ConstFakeQuantRewrite 86363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 87363dd3f3SRob Suderman public: 88363dd3f3SRob Suderman using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 89363dd3f3SRob Suderman 90363dd3f3SRob Suderman ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 91363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 92363dd3f3SRob Suderman 93363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 94363dd3f3SRob Suderman Type expressedType) const { 95363dd3f3SRob Suderman return fakeQuantAttrsToType( 96363dd3f3SRob Suderman fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 97363dd3f3SRob Suderman fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 98363dd3f3SRob Suderman fqOp.narrow_range(), expressedType, fqOp.is_signed()); 99363dd3f3SRob Suderman } 100363dd3f3SRob Suderman }; 101363dd3f3SRob Suderman 102363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite 103363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 104363dd3f3SRob Suderman ConstFakeQuantPerAxis> { 105363dd3f3SRob Suderman public: 106363dd3f3SRob Suderman using BaseRewrite = 107363dd3f3SRob Suderman FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 108363dd3f3SRob Suderman 109363dd3f3SRob Suderman ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 110363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 111363dd3f3SRob Suderman 112363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 113363dd3f3SRob Suderman Type expressedType) const { 114363dd3f3SRob Suderman SmallVector<double, 4> min, max; 115363dd3f3SRob Suderman min.reserve(fqOp.min().size()); 116363dd3f3SRob Suderman max.reserve(fqOp.max().size()); 117363dd3f3SRob Suderman for (auto m : fqOp.min()) 118363dd3f3SRob Suderman min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 119363dd3f3SRob Suderman for (auto m : fqOp.max()) 120363dd3f3SRob Suderman max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 121363dd3f3SRob Suderman 122363dd3f3SRob Suderman return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 123363dd3f3SRob Suderman fqOp.axis().getSExtValue(), min, max, 124363dd3f3SRob Suderman fqOp.narrow_range(), expressedType, 125363dd3f3SRob Suderman fqOp.is_signed()); 126363dd3f3SRob Suderman } 127363dd3f3SRob Suderman }; 128363dd3f3SRob Suderman 129363dd3f3SRob Suderman } // namespace 130363dd3f3SRob Suderman 131363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() { 132363dd3f3SRob Suderman bool hadFailure = false; 133363dd3f3SRob Suderman OwningRewritePatternList patterns; 134363dd3f3SRob Suderman auto func = getFunction(); 135363dd3f3SRob Suderman auto ctx = func.getContext(); 136363dd3f3SRob Suderman patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 137363dd3f3SRob Suderman ctx, &hadFailure); 138363dd3f3SRob Suderman applyPatternsGreedily(func, patterns); 139363dd3f3SRob Suderman if (hadFailure) 140363dd3f3SRob Suderman signalPassFailure(); 141363dd3f3SRob Suderman } 142363dd3f3SRob Suderman 143*80aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> 144363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() { 145363dd3f3SRob Suderman return std::make_unique<ConvertSimulatedQuantPass>(); 146363dd3f3SRob Suderman } 147