1*363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2*363dd3f3SRob Suderman //
3*363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5*363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*363dd3f3SRob Suderman //
7*363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8*363dd3f3SRob Suderman 
9*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h"
10*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
11*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
12*363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
13*363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
14*363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h"
15*363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h"
16*363dd3f3SRob Suderman #include "mlir/Pass/Pass.h"
17*363dd3f3SRob Suderman 
18*363dd3f3SRob Suderman using namespace mlir;
19*363dd3f3SRob Suderman using namespace mlir::quant;
20*363dd3f3SRob Suderman 
21*363dd3f3SRob Suderman namespace {
22*363dd3f3SRob Suderman 
23*363dd3f3SRob Suderman class ConvertSimulatedQuantPass
24*363dd3f3SRob Suderman     : public FunctionPass<ConvertSimulatedQuantPass> {
25*363dd3f3SRob Suderman public:
26*363dd3f3SRob Suderman   void runOnFunction() override;
27*363dd3f3SRob Suderman };
28*363dd3f3SRob Suderman 
29*363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
30*363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp>
31*363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
32*363dd3f3SRob Suderman public:
33*363dd3f3SRob Suderman   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
34*363dd3f3SRob Suderman 
35*363dd3f3SRob Suderman   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
36*363dd3f3SRob Suderman       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
37*363dd3f3SRob Suderman 
38*363dd3f3SRob Suderman   PatternMatchResult matchAndRewrite(FakeQuantOp op,
39*363dd3f3SRob Suderman                                      PatternRewriter &rewriter) const override {
40*363dd3f3SRob Suderman     // TODO: If this pattern comes up more frequently, consider adding core
41*363dd3f3SRob Suderman     // support for failable rewrites.
42*363dd3f3SRob Suderman     if (failableRewrite(op, rewriter)) {
43*363dd3f3SRob Suderman       *hadFailure = true;
44*363dd3f3SRob Suderman       return Pattern::matchFailure();
45*363dd3f3SRob Suderman     }
46*363dd3f3SRob Suderman 
47*363dd3f3SRob Suderman     return Pattern::matchSuccess();
48*363dd3f3SRob Suderman   }
49*363dd3f3SRob Suderman 
50*363dd3f3SRob Suderman private:
51*363dd3f3SRob Suderman   bool *hadFailure;
52*363dd3f3SRob Suderman 
53*363dd3f3SRob Suderman   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
54*363dd3f3SRob Suderman     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
55*363dd3f3SRob Suderman     if (!converter) {
56*363dd3f3SRob Suderman       return (op.emitError("unsupported quantized type conversion"), true);
57*363dd3f3SRob Suderman     }
58*363dd3f3SRob Suderman 
59*363dd3f3SRob Suderman     QuantizedType elementType =
60*363dd3f3SRob Suderman         static_cast<const ConcreteRewriteClass *>(this)
61*363dd3f3SRob Suderman             ->convertFakeQuantAttrsToType(op, converter.expressedType);
62*363dd3f3SRob Suderman 
63*363dd3f3SRob Suderman     if (!elementType) {
64*363dd3f3SRob Suderman       // Note that the fakeQuantAttrsToType will have emitted the error.
65*363dd3f3SRob Suderman       return true;
66*363dd3f3SRob Suderman     }
67*363dd3f3SRob Suderman 
68*363dd3f3SRob Suderman     Type quantizedType = converter.convert(elementType);
69*363dd3f3SRob Suderman     assert(quantizedType &&
70*363dd3f3SRob Suderman            "Converter accepted a type that it did not convert");
71*363dd3f3SRob Suderman 
72*363dd3f3SRob Suderman     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
73*363dd3f3SRob Suderman     // this is a forced/hard-coded constraint.
74*363dd3f3SRob Suderman     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
75*363dd3f3SRob Suderman                                                     op.inputs());
76*363dd3f3SRob Suderman     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
77*363dd3f3SRob Suderman                                                   qbarrier.getResult());
78*363dd3f3SRob Suderman 
79*363dd3f3SRob Suderman     return false;
80*363dd3f3SRob Suderman   }
81*363dd3f3SRob Suderman };
82*363dd3f3SRob Suderman 
83*363dd3f3SRob Suderman class ConstFakeQuantRewrite
84*363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
85*363dd3f3SRob Suderman public:
86*363dd3f3SRob Suderman   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
87*363dd3f3SRob Suderman 
88*363dd3f3SRob Suderman   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
89*363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
90*363dd3f3SRob Suderman 
91*363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
92*363dd3f3SRob Suderman                                             Type expressedType) const {
93*363dd3f3SRob Suderman     return fakeQuantAttrsToType(
94*363dd3f3SRob Suderman         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
95*363dd3f3SRob Suderman         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
96*363dd3f3SRob Suderman         fqOp.narrow_range(), expressedType, fqOp.is_signed());
97*363dd3f3SRob Suderman   }
98*363dd3f3SRob Suderman };
99*363dd3f3SRob Suderman 
100*363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite
101*363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
102*363dd3f3SRob Suderman                               ConstFakeQuantPerAxis> {
103*363dd3f3SRob Suderman public:
104*363dd3f3SRob Suderman   using BaseRewrite =
105*363dd3f3SRob Suderman       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
106*363dd3f3SRob Suderman 
107*363dd3f3SRob Suderman   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
108*363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
109*363dd3f3SRob Suderman 
110*363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
111*363dd3f3SRob Suderman                                             Type expressedType) const {
112*363dd3f3SRob Suderman     SmallVector<double, 4> min, max;
113*363dd3f3SRob Suderman     min.reserve(fqOp.min().size());
114*363dd3f3SRob Suderman     max.reserve(fqOp.max().size());
115*363dd3f3SRob Suderman     for (auto m : fqOp.min())
116*363dd3f3SRob Suderman       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
117*363dd3f3SRob Suderman     for (auto m : fqOp.max())
118*363dd3f3SRob Suderman       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
119*363dd3f3SRob Suderman 
120*363dd3f3SRob Suderman     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
121*363dd3f3SRob Suderman                                 fqOp.axis().getSExtValue(), min, max,
122*363dd3f3SRob Suderman                                 fqOp.narrow_range(), expressedType,
123*363dd3f3SRob Suderman                                 fqOp.is_signed());
124*363dd3f3SRob Suderman   }
125*363dd3f3SRob Suderman };
126*363dd3f3SRob Suderman 
127*363dd3f3SRob Suderman } // namespace
128*363dd3f3SRob Suderman 
129*363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() {
130*363dd3f3SRob Suderman   bool hadFailure = false;
131*363dd3f3SRob Suderman   OwningRewritePatternList patterns;
132*363dd3f3SRob Suderman   auto func = getFunction();
133*363dd3f3SRob Suderman   auto ctx = func.getContext();
134*363dd3f3SRob Suderman   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
135*363dd3f3SRob Suderman       ctx, &hadFailure);
136*363dd3f3SRob Suderman   applyPatternsGreedily(func, patterns);
137*363dd3f3SRob Suderman   if (hadFailure)
138*363dd3f3SRob Suderman     signalPassFailure();
139*363dd3f3SRob Suderman }
140*363dd3f3SRob Suderman 
141*363dd3f3SRob Suderman std::unique_ptr<OpPassBase<FuncOp>>
142*363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() {
143*363dd3f3SRob Suderman   return std::make_unique<ConvertSimulatedQuantPass>();
144*363dd3f3SRob Suderman }
145*363dd3f3SRob Suderman 
146*363dd3f3SRob Suderman static PassRegistration<ConvertSimulatedQuantPass>
147*363dd3f3SRob Suderman     pass("quant-convert-simulated-quantization",
148*363dd3f3SRob Suderman          "Converts training-time simulated quantization ops to corresponding "
149*363dd3f3SRob Suderman          "quantize/dequantize casts.");
150