1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman
91834ad4aSRiver Riddle #include "PassDetail.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
1409f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
15b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16363dd3f3SRob Suderman
17363dd3f3SRob Suderman using namespace mlir;
18363dd3f3SRob Suderman using namespace mlir::quant;
19363dd3f3SRob Suderman
20363dd3f3SRob Suderman namespace {
219a277af2SRiver Riddle struct ConvertSimulatedQuantPass
221834ad4aSRiver Riddle : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
2341574554SRiver Riddle void runOnOperation() override;
24363dd3f3SRob Suderman };
25363dd3f3SRob Suderman
26363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
27363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp>
28363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
29363dd3f3SRob Suderman public:
30363dd3f3SRob Suderman using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
31363dd3f3SRob Suderman
FakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)32363dd3f3SRob Suderman FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
33363dd3f3SRob Suderman : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
34363dd3f3SRob Suderman
matchAndRewrite(FakeQuantOp op,PatternRewriter & rewriter) const353145427dSRiver Riddle LogicalResult matchAndRewrite(FakeQuantOp op,
36363dd3f3SRob Suderman PatternRewriter &rewriter) const override {
37363dd3f3SRob Suderman // TODO: If this pattern comes up more frequently, consider adding core
38363dd3f3SRob Suderman // support for failable rewrites.
39363dd3f3SRob Suderman if (failableRewrite(op, rewriter)) {
40363dd3f3SRob Suderman *hadFailure = true;
413145427dSRiver Riddle return failure();
42363dd3f3SRob Suderman }
43363dd3f3SRob Suderman
443145427dSRiver Riddle return success();
45363dd3f3SRob Suderman }
46363dd3f3SRob Suderman
47363dd3f3SRob Suderman private:
48363dd3f3SRob Suderman bool *hadFailure;
49363dd3f3SRob Suderman
failableRewrite(FakeQuantOp op,PatternRewriter & rewriter) const50363dd3f3SRob Suderman bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
51363dd3f3SRob Suderman auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
52363dd3f3SRob Suderman if (!converter) {
53363dd3f3SRob Suderman return (op.emitError("unsupported quantized type conversion"), true);
54363dd3f3SRob Suderman }
55363dd3f3SRob Suderman
56363dd3f3SRob Suderman QuantizedType elementType =
57363dd3f3SRob Suderman static_cast<const ConcreteRewriteClass *>(this)
58363dd3f3SRob Suderman ->convertFakeQuantAttrsToType(op, converter.expressedType);
59363dd3f3SRob Suderman
60363dd3f3SRob Suderman if (!elementType) {
61363dd3f3SRob Suderman // Note that the fakeQuantAttrsToType will have emitted the error.
62363dd3f3SRob Suderman return true;
63363dd3f3SRob Suderman }
64363dd3f3SRob Suderman
65363dd3f3SRob Suderman Type quantizedType = converter.convert(elementType);
66363dd3f3SRob Suderman assert(quantizedType &&
67363dd3f3SRob Suderman "Converter accepted a type that it did not convert");
68363dd3f3SRob Suderman
69363dd3f3SRob Suderman // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
70363dd3f3SRob Suderman // this is a forced/hard-coded constraint.
71363dd3f3SRob Suderman auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
72*04235d07SJacques Pienaar op.getInputs());
73363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
74363dd3f3SRob Suderman qbarrier.getResult());
75363dd3f3SRob Suderman
76363dd3f3SRob Suderman return false;
77363dd3f3SRob Suderman }
78363dd3f3SRob Suderman };
79363dd3f3SRob Suderman
80363dd3f3SRob Suderman class ConstFakeQuantRewrite
81363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
82363dd3f3SRob Suderman public:
83363dd3f3SRob Suderman using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
84363dd3f3SRob Suderman
ConstFakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)85363dd3f3SRob Suderman ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
86363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {}
87363dd3f3SRob Suderman
convertFakeQuantAttrsToType(ConstFakeQuant fqOp,Type expressedType) const88363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
89363dd3f3SRob Suderman Type expressedType) const {
90363dd3f3SRob Suderman return fakeQuantAttrsToType(
91*04235d07SJacques Pienaar fqOp.getLoc(), fqOp.getNumBits(), fqOp.getMin().convertToFloat(),
92*04235d07SJacques Pienaar fqOp.getMax().convertToFloat(), fqOp.getNarrowRange(), expressedType,
93*04235d07SJacques Pienaar fqOp.getIsSigned());
94363dd3f3SRob Suderman }
95363dd3f3SRob Suderman };
96363dd3f3SRob Suderman
97363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite
98363dd3f3SRob Suderman : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
99363dd3f3SRob Suderman ConstFakeQuantPerAxis> {
100363dd3f3SRob Suderman public:
101363dd3f3SRob Suderman using BaseRewrite =
102363dd3f3SRob Suderman FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
103363dd3f3SRob Suderman
ConstFakeQuantPerAxisRewrite(MLIRContext * ctx,bool * hadFailure)104363dd3f3SRob Suderman ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
105363dd3f3SRob Suderman : BaseRewrite(ctx, hadFailure) {}
106363dd3f3SRob Suderman
convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,Type expressedType) const107363dd3f3SRob Suderman QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
108363dd3f3SRob Suderman Type expressedType) const {
109363dd3f3SRob Suderman SmallVector<double, 4> min, max;
110*04235d07SJacques Pienaar min.reserve(fqOp.getMin().size());
111*04235d07SJacques Pienaar max.reserve(fqOp.getMax().size());
112*04235d07SJacques Pienaar for (auto m : fqOp.getMin())
113363dd3f3SRob Suderman min.push_back(m.cast<FloatAttr>().getValueAsDouble());
114*04235d07SJacques Pienaar for (auto m : fqOp.getMax())
115363dd3f3SRob Suderman max.push_back(m.cast<FloatAttr>().getValueAsDouble());
116363dd3f3SRob Suderman
117*04235d07SJacques Pienaar return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(),
118*04235d07SJacques Pienaar fqOp.getAxis(), min, max, fqOp.getNarrowRange(),
119*04235d07SJacques Pienaar expressedType, fqOp.getIsSigned());
120363dd3f3SRob Suderman }
121363dd3f3SRob Suderman };
122363dd3f3SRob Suderman
123363dd3f3SRob Suderman } // namespace
124363dd3f3SRob Suderman
runOnOperation()12541574554SRiver Riddle void ConvertSimulatedQuantPass::runOnOperation() {
126363dd3f3SRob Suderman bool hadFailure = false;
12741574554SRiver Riddle auto func = getOperation();
128dc4e913bSChris Lattner RewritePatternSet patterns(func.getContext());
12902b6fb21SMehdi Amini auto *ctx = func.getContext();
130dc4e913bSChris Lattner patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
131363dd3f3SRob Suderman ctx, &hadFailure);
132e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
133363dd3f3SRob Suderman if (hadFailure)
134363dd3f3SRob Suderman signalPassFailure();
135363dd3f3SRob Suderman }
136363dd3f3SRob Suderman
13758ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createConvertSimulatedQuantPass()138363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() {
139363dd3f3SRob Suderman return std::make_unique<ConvertSimulatedQuantPass>();
140363dd3f3SRob Suderman }
141