1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
91834ad4aSRiver Riddle #include "PassDetail.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
1409f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
15b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16363dd3f3SRob Suderman 
17363dd3f3SRob Suderman using namespace mlir;
18363dd3f3SRob Suderman using namespace mlir::quant;
19363dd3f3SRob Suderman 
20363dd3f3SRob Suderman namespace {
219a277af2SRiver Riddle struct ConvertSimulatedQuantPass
221834ad4aSRiver Riddle     : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
2341574554SRiver Riddle   void runOnOperation() override;
24363dd3f3SRob Suderman };
25363dd3f3SRob Suderman 
26363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
27363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp>
28363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
29363dd3f3SRob Suderman public:
30363dd3f3SRob Suderman   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
31363dd3f3SRob Suderman 
FakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)32363dd3f3SRob Suderman   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
33363dd3f3SRob Suderman       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
34363dd3f3SRob Suderman 
matchAndRewrite(FakeQuantOp op,PatternRewriter & rewriter) const353145427dSRiver Riddle   LogicalResult matchAndRewrite(FakeQuantOp op,
36363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override {
37363dd3f3SRob Suderman     // TODO: If this pattern comes up more frequently, consider adding core
38363dd3f3SRob Suderman     // support for failable rewrites.
39363dd3f3SRob Suderman     if (failableRewrite(op, rewriter)) {
40363dd3f3SRob Suderman       *hadFailure = true;
413145427dSRiver Riddle       return failure();
42363dd3f3SRob Suderman     }
43363dd3f3SRob Suderman 
443145427dSRiver Riddle     return success();
45363dd3f3SRob Suderman   }
46363dd3f3SRob Suderman 
47363dd3f3SRob Suderman private:
48363dd3f3SRob Suderman   bool *hadFailure;
49363dd3f3SRob Suderman 
failableRewrite(FakeQuantOp op,PatternRewriter & rewriter) const50363dd3f3SRob Suderman   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
51363dd3f3SRob Suderman     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
52363dd3f3SRob Suderman     if (!converter) {
53363dd3f3SRob Suderman       return (op.emitError("unsupported quantized type conversion"), true);
54363dd3f3SRob Suderman     }
55363dd3f3SRob Suderman 
56363dd3f3SRob Suderman     QuantizedType elementType =
57363dd3f3SRob Suderman         static_cast<const ConcreteRewriteClass *>(this)
58363dd3f3SRob Suderman             ->convertFakeQuantAttrsToType(op, converter.expressedType);
59363dd3f3SRob Suderman 
60363dd3f3SRob Suderman     if (!elementType) {
61363dd3f3SRob Suderman       // Note that the fakeQuantAttrsToType will have emitted the error.
62363dd3f3SRob Suderman       return true;
63363dd3f3SRob Suderman     }
64363dd3f3SRob Suderman 
65363dd3f3SRob Suderman     Type quantizedType = converter.convert(elementType);
66363dd3f3SRob Suderman     assert(quantizedType &&
67363dd3f3SRob Suderman            "Converter accepted a type that it did not convert");
68363dd3f3SRob Suderman 
69363dd3f3SRob Suderman     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
70363dd3f3SRob Suderman     // this is a forced/hard-coded constraint.
71363dd3f3SRob Suderman     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
72*04235d07SJacques Pienaar                                                     op.getInputs());
73363dd3f3SRob Suderman     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
74363dd3f3SRob Suderman                                                   qbarrier.getResult());
75363dd3f3SRob Suderman 
76363dd3f3SRob Suderman     return false;
77363dd3f3SRob Suderman   }
78363dd3f3SRob Suderman };
79363dd3f3SRob Suderman 
80363dd3f3SRob Suderman class ConstFakeQuantRewrite
81363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
82363dd3f3SRob Suderman public:
83363dd3f3SRob Suderman   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
84363dd3f3SRob Suderman 
ConstFakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)85363dd3f3SRob Suderman   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
86363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
87363dd3f3SRob Suderman 
convertFakeQuantAttrsToType(ConstFakeQuant fqOp,Type expressedType) const88363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
89363dd3f3SRob Suderman                                             Type expressedType) const {
90363dd3f3SRob Suderman     return fakeQuantAttrsToType(
91*04235d07SJacques Pienaar         fqOp.getLoc(), fqOp.getNumBits(), fqOp.getMin().convertToFloat(),
92*04235d07SJacques Pienaar         fqOp.getMax().convertToFloat(), fqOp.getNarrowRange(), expressedType,
93*04235d07SJacques Pienaar         fqOp.getIsSigned());
94363dd3f3SRob Suderman   }
95363dd3f3SRob Suderman };
96363dd3f3SRob Suderman 
97363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite
98363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
99363dd3f3SRob Suderman                               ConstFakeQuantPerAxis> {
100363dd3f3SRob Suderman public:
101363dd3f3SRob Suderman   using BaseRewrite =
102363dd3f3SRob Suderman       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
103363dd3f3SRob Suderman 
ConstFakeQuantPerAxisRewrite(MLIRContext * ctx,bool * hadFailure)104363dd3f3SRob Suderman   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
105363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
106363dd3f3SRob Suderman 
convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,Type expressedType) const107363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
108363dd3f3SRob Suderman                                             Type expressedType) const {
109363dd3f3SRob Suderman     SmallVector<double, 4> min, max;
110*04235d07SJacques Pienaar     min.reserve(fqOp.getMin().size());
111*04235d07SJacques Pienaar     max.reserve(fqOp.getMax().size());
112*04235d07SJacques Pienaar     for (auto m : fqOp.getMin())
113363dd3f3SRob Suderman       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
114*04235d07SJacques Pienaar     for (auto m : fqOp.getMax())
115363dd3f3SRob Suderman       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
116363dd3f3SRob Suderman 
117*04235d07SJacques Pienaar     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(),
118*04235d07SJacques Pienaar                                 fqOp.getAxis(), min, max, fqOp.getNarrowRange(),
119*04235d07SJacques Pienaar                                 expressedType, fqOp.getIsSigned());
120363dd3f3SRob Suderman   }
121363dd3f3SRob Suderman };
122363dd3f3SRob Suderman 
123363dd3f3SRob Suderman } // namespace
124363dd3f3SRob Suderman 
runOnOperation()12541574554SRiver Riddle void ConvertSimulatedQuantPass::runOnOperation() {
126363dd3f3SRob Suderman   bool hadFailure = false;
12741574554SRiver Riddle   auto func = getOperation();
128dc4e913bSChris Lattner   RewritePatternSet patterns(func.getContext());
12902b6fb21SMehdi Amini   auto *ctx = func.getContext();
130dc4e913bSChris Lattner   patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
131363dd3f3SRob Suderman       ctx, &hadFailure);
132e21adfa3SRiver Riddle   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
133363dd3f3SRob Suderman   if (hadFailure)
134363dd3f3SRob Suderman     signalPassFailure();
135363dd3f3SRob Suderman }
136363dd3f3SRob Suderman 
13758ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createConvertSimulatedQuantPass()138363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() {
139363dd3f3SRob Suderman   return std::make_unique<ConvertSimulatedQuantPass>();
140363dd3f3SRob Suderman }
141