1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 91834ad4aSRiver Riddle #include "PassDetail.h" 10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h" 11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 1409f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 15b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16363dd3f3SRob Suderman 17363dd3f3SRob Suderman using namespace mlir; 18363dd3f3SRob Suderman using namespace mlir::quant; 19363dd3f3SRob Suderman 20363dd3f3SRob Suderman namespace { 219a277af2SRiver Riddle struct ConvertSimulatedQuantPass 221834ad4aSRiver Riddle : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> { 23363dd3f3SRob Suderman void runOnFunction() override; 24363dd3f3SRob Suderman }; 25363dd3f3SRob Suderman 26363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 27363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp> 28363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 29363dd3f3SRob Suderman public: 30363dd3f3SRob Suderman using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 31363dd3f3SRob Suderman 32363dd3f3SRob Suderman FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 33363dd3f3SRob Suderman : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 34363dd3f3SRob Suderman 353145427dSRiver Riddle LogicalResult matchAndRewrite(FakeQuantOp op, 36363dd3f3SRob Suderman PatternRewriter &rewriter) const override { 37363dd3f3SRob Suderman // TODO: If this pattern comes up more frequently, consider adding core 38363dd3f3SRob Suderman // support for failable rewrites. 39363dd3f3SRob Suderman if (failableRewrite(op, rewriter)) { 40363dd3f3SRob Suderman *hadFailure = true; 413145427dSRiver Riddle return failure(); 42363dd3f3SRob Suderman } 43363dd3f3SRob Suderman 443145427dSRiver Riddle return success(); 45363dd3f3SRob Suderman } 46363dd3f3SRob Suderman 47363dd3f3SRob Suderman private: 48363dd3f3SRob Suderman bool *hadFailure; 49363dd3f3SRob Suderman 50363dd3f3SRob Suderman bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 51363dd3f3SRob Suderman auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 52363dd3f3SRob Suderman if (!converter) { 53363dd3f3SRob Suderman return (op.emitError("unsupported quantized type conversion"), true); 54363dd3f3SRob Suderman } 55363dd3f3SRob Suderman 56363dd3f3SRob Suderman QuantizedType elementType = 57363dd3f3SRob Suderman static_cast<const ConcreteRewriteClass *>(this) 58363dd3f3SRob Suderman ->convertFakeQuantAttrsToType(op, converter.expressedType); 59363dd3f3SRob Suderman 60363dd3f3SRob Suderman if (!elementType) { 61363dd3f3SRob Suderman // Note that the fakeQuantAttrsToType will have emitted the error. 62363dd3f3SRob Suderman return true; 63363dd3f3SRob Suderman } 64363dd3f3SRob Suderman 65363dd3f3SRob Suderman Type quantizedType = converter.convert(elementType); 66363dd3f3SRob Suderman assert(quantizedType && 67363dd3f3SRob Suderman "Converter accepted a type that it did not convert"); 68363dd3f3SRob Suderman 69363dd3f3SRob Suderman // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 70363dd3f3SRob Suderman // this is a forced/hard-coded constraint. 71363dd3f3SRob Suderman auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 72363dd3f3SRob Suderman op.inputs()); 73363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 74363dd3f3SRob Suderman qbarrier.getResult()); 75363dd3f3SRob Suderman 76363dd3f3SRob Suderman return false; 77363dd3f3SRob Suderman } 78363dd3f3SRob Suderman }; 79363dd3f3SRob Suderman 80363dd3f3SRob Suderman class ConstFakeQuantRewrite 81363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 82363dd3f3SRob Suderman public: 83363dd3f3SRob Suderman using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 84363dd3f3SRob Suderman 85363dd3f3SRob Suderman ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 86363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 87363dd3f3SRob Suderman 88363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 89363dd3f3SRob Suderman Type expressedType) const { 90363dd3f3SRob Suderman return fakeQuantAttrsToType( 91431bb8b3SRiver Riddle fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(), 92431bb8b3SRiver Riddle fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType, 93431bb8b3SRiver Riddle fqOp.is_signed()); 94363dd3f3SRob Suderman } 95363dd3f3SRob Suderman }; 96363dd3f3SRob Suderman 97363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite 98363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 99363dd3f3SRob Suderman ConstFakeQuantPerAxis> { 100363dd3f3SRob Suderman public: 101363dd3f3SRob Suderman using BaseRewrite = 102363dd3f3SRob Suderman FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 103363dd3f3SRob Suderman 104363dd3f3SRob Suderman ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 105363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {} 106363dd3f3SRob Suderman 107363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 108363dd3f3SRob Suderman Type expressedType) const { 109363dd3f3SRob Suderman SmallVector<double, 4> min, max; 110363dd3f3SRob Suderman min.reserve(fqOp.min().size()); 111363dd3f3SRob Suderman max.reserve(fqOp.max().size()); 112363dd3f3SRob Suderman for (auto m : fqOp.min()) 113363dd3f3SRob Suderman min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 114363dd3f3SRob Suderman for (auto m : fqOp.max()) 115363dd3f3SRob Suderman max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 116363dd3f3SRob Suderman 117431bb8b3SRiver Riddle return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(), 118431bb8b3SRiver Riddle min, max, fqOp.narrow_range(), expressedType, 119363dd3f3SRob Suderman fqOp.is_signed()); 120363dd3f3SRob Suderman } 121363dd3f3SRob Suderman }; 122363dd3f3SRob Suderman 123363dd3f3SRob Suderman } // namespace 124363dd3f3SRob Suderman 125363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() { 126363dd3f3SRob Suderman bool hadFailure = false; 127363dd3f3SRob Suderman auto func = getFunction(); 128*dc4e913bSChris Lattner RewritePatternSet patterns(func.getContext()); 129363dd3f3SRob Suderman auto ctx = func.getContext(); 130*dc4e913bSChris Lattner patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 131363dd3f3SRob Suderman ctx, &hadFailure); 132e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 133363dd3f3SRob Suderman if (hadFailure) 134363dd3f3SRob Suderman signalPassFailure(); 135363dd3f3SRob Suderman } 136363dd3f3SRob Suderman 13780aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> 138363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() { 139363dd3f3SRob Suderman return std::make_unique<ConvertSimulatedQuantPass>(); 140363dd3f3SRob Suderman } 141