1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 91834ad4aSRiver Riddle #include "PassDetail.h" 10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h" 11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 14363dd3f3SRob Suderman #include "mlir/IR/Attributes.h" 15363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h" 16363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h" 17363dd3f3SRob Suderman 18363dd3f3SRob Suderman using namespace mlir; 19363dd3f3SRob Suderman using namespace mlir::quant; 20363dd3f3SRob Suderman 21363dd3f3SRob Suderman namespace { 229a277af2SRiver Riddle struct ConvertSimulatedQuantPass 231834ad4aSRiver Riddle : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> { 24363dd3f3SRob Suderman void runOnFunction() override; 25363dd3f3SRob Suderman }; 26363dd3f3SRob Suderman 27363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 28363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp> 29363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 30363dd3f3SRob Suderman public: 31363dd3f3SRob Suderman using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 32363dd3f3SRob Suderman 33363dd3f3SRob Suderman FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 34363dd3f3SRob Suderman : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 35363dd3f3SRob Suderman 363145427dSRiver Riddle LogicalResult matchAndRewrite(FakeQuantOp op, 37363dd3f3SRob Suderman PatternRewriter &rewriter) const override { 38363dd3f3SRob Suderman // TODO: If this pattern comes up more frequently, consider adding core 39363dd3f3SRob Suderman // support for failable rewrites. 40363dd3f3SRob Suderman if (failableRewrite(op, rewriter)) { 41363dd3f3SRob Suderman *hadFailure = true; 423145427dSRiver Riddle return failure(); 43363dd3f3SRob Suderman } 44363dd3f3SRob Suderman 453145427dSRiver Riddle return success(); 46363dd3f3SRob Suderman } 47363dd3f3SRob Suderman 48363dd3f3SRob Suderman private: 49363dd3f3SRob Suderman bool *hadFailure; 50363dd3f3SRob Suderman 51363dd3f3SRob Suderman bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 52363dd3f3SRob Suderman auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 53363dd3f3SRob Suderman if (!converter) { 54363dd3f3SRob Suderman return (op.emitError("unsupported quantized type conversion"), true); 55363dd3f3SRob Suderman } 56363dd3f3SRob Suderman 57363dd3f3SRob Suderman QuantizedType elementType = 58363dd3f3SRob Suderman static_cast<const ConcreteRewriteClass *>(this) 59363dd3f3SRob Suderman ->convertFakeQuantAttrsToType(op, converter.expressedType); 60363dd3f3SRob Suderman 61363dd3f3SRob Suderman if (!elementType) { 62363dd3f3SRob Suderman // Note that the fakeQuantAttrsToType will have emitted the error. 63363dd3f3SRob Suderman return true; 64363dd3f3SRob Suderman } 65363dd3f3SRob Suderman 66363dd3f3SRob Suderman Type quantizedType = converter.convert(elementType); 67363dd3f3SRob Suderman assert(quantizedType && 68363dd3f3SRob Suderman "Converter accepted a type that it did not convert"); 69363dd3f3SRob Suderman 70363dd3f3SRob Suderman // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 71363dd3f3SRob Suderman // this is a forced/hard-coded constraint. 72363dd3f3SRob Suderman auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 73363dd3f3SRob Suderman op.inputs()); 74363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 75363dd3f3SRob Suderman qbarrier.getResult()); 76363dd3f3SRob Suderman 77363dd3f3SRob Suderman return false; 78363dd3f3SRob Suderman } 79363dd3f3SRob Suderman }; 80363dd3f3SRob Suderman 81363dd3f3SRob Suderman class ConstFakeQuantRewrite 82363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 83363dd3f3SRob Suderman public: 84363dd3f3SRob Suderman using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 85363dd3f3SRob Suderman 86363dd3f3SRob Suderman ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 87363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 88363dd3f3SRob Suderman 89363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 90363dd3f3SRob Suderman Type expressedType) const { 91363dd3f3SRob Suderman return fakeQuantAttrsToType( 92*431bb8b3SRiver Riddle fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(), 93*431bb8b3SRiver Riddle fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType, 94*431bb8b3SRiver Riddle fqOp.is_signed()); 95363dd3f3SRob Suderman } 96363dd3f3SRob Suderman }; 97363dd3f3SRob Suderman 98363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite 99363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 100363dd3f3SRob Suderman ConstFakeQuantPerAxis> { 101363dd3f3SRob Suderman public: 102363dd3f3SRob Suderman using BaseRewrite = 103363dd3f3SRob Suderman FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 104363dd3f3SRob Suderman 105363dd3f3SRob Suderman ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 106363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 107363dd3f3SRob Suderman 108363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 109363dd3f3SRob Suderman Type expressedType) const { 110363dd3f3SRob Suderman SmallVector<double, 4> min, max; 111363dd3f3SRob Suderman min.reserve(fqOp.min().size()); 112363dd3f3SRob Suderman max.reserve(fqOp.max().size()); 113363dd3f3SRob Suderman for (auto m : fqOp.min()) 114363dd3f3SRob Suderman min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 115363dd3f3SRob Suderman for (auto m : fqOp.max()) 116363dd3f3SRob Suderman max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 117363dd3f3SRob Suderman 118*431bb8b3SRiver Riddle return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(), 119*431bb8b3SRiver Riddle min, max, fqOp.narrow_range(), expressedType, 120363dd3f3SRob Suderman fqOp.is_signed()); 121363dd3f3SRob Suderman } 122363dd3f3SRob Suderman }; 123363dd3f3SRob Suderman 124363dd3f3SRob Suderman } // namespace 125363dd3f3SRob Suderman 126363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() { 127363dd3f3SRob Suderman bool hadFailure = false; 128363dd3f3SRob Suderman OwningRewritePatternList patterns; 129363dd3f3SRob Suderman auto func = getFunction(); 130363dd3f3SRob Suderman auto ctx = func.getContext(); 131363dd3f3SRob Suderman patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 132363dd3f3SRob Suderman ctx, &hadFailure); 133a5b9316bSUday Bondhugula applyPatternsAndFoldGreedily(func, patterns); 134363dd3f3SRob Suderman if (hadFailure) 135363dd3f3SRob Suderman signalPassFailure(); 136363dd3f3SRob Suderman } 137363dd3f3SRob Suderman 13880aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> 139363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() { 140363dd3f3SRob Suderman return std::make_unique<ConvertSimulatedQuantPass>(); 141363dd3f3SRob Suderman } 142