1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/FakeQuantSupport.h"
10 #include "mlir/Dialect/Quant/Passes.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 #include "mlir/Dialect/Quant/UniformSupport.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::quant;
20 
21 namespace {
22 
23 class ConvertSimulatedQuantPass
24     : public FunctionPass<ConvertSimulatedQuantPass> {
25 public:
26   void runOnFunction() override;
27 };
28 
29 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
30 template <typename ConcreteRewriteClass, typename FakeQuantOp>
31 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
32 public:
33   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
34 
35   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
36       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
37 
38   LogicalResult matchAndRewrite(FakeQuantOp op,
39                                 PatternRewriter &rewriter) const override {
40     // TODO: If this pattern comes up more frequently, consider adding core
41     // support for failable rewrites.
42     if (failableRewrite(op, rewriter)) {
43       *hadFailure = true;
44       return failure();
45     }
46 
47     return success();
48   }
49 
50 private:
51   bool *hadFailure;
52 
53   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
54     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
55     if (!converter) {
56       return (op.emitError("unsupported quantized type conversion"), true);
57     }
58 
59     QuantizedType elementType =
60         static_cast<const ConcreteRewriteClass *>(this)
61             ->convertFakeQuantAttrsToType(op, converter.expressedType);
62 
63     if (!elementType) {
64       // Note that the fakeQuantAttrsToType will have emitted the error.
65       return true;
66     }
67 
68     Type quantizedType = converter.convert(elementType);
69     assert(quantizedType &&
70            "Converter accepted a type that it did not convert");
71 
72     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
73     // this is a forced/hard-coded constraint.
74     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
75                                                     op.inputs());
76     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
77                                                   qbarrier.getResult());
78 
79     return false;
80   }
81 };
82 
83 class ConstFakeQuantRewrite
84     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
85 public:
86   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
87 
88   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
89       : BaseRewrite(ctx, hadFailure) {}
90 
91   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
92                                             Type expressedType) const {
93     return fakeQuantAttrsToType(
94         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
95         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
96         fqOp.narrow_range(), expressedType, fqOp.is_signed());
97   }
98 };
99 
100 class ConstFakeQuantPerAxisRewrite
101     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
102                               ConstFakeQuantPerAxis> {
103 public:
104   using BaseRewrite =
105       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
106 
107   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
108       : BaseRewrite(ctx, hadFailure) {}
109 
110   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
111                                             Type expressedType) const {
112     SmallVector<double, 4> min, max;
113     min.reserve(fqOp.min().size());
114     max.reserve(fqOp.max().size());
115     for (auto m : fqOp.min())
116       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
117     for (auto m : fqOp.max())
118       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
119 
120     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
121                                 fqOp.axis().getSExtValue(), min, max,
122                                 fqOp.narrow_range(), expressedType,
123                                 fqOp.is_signed());
124   }
125 };
126 
127 } // namespace
128 
129 void ConvertSimulatedQuantPass::runOnFunction() {
130   bool hadFailure = false;
131   OwningRewritePatternList patterns;
132   auto func = getFunction();
133   auto ctx = func.getContext();
134   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
135       ctx, &hadFailure);
136   applyPatternsGreedily(func, patterns);
137   if (hadFailure)
138     signalPassFailure();
139 }
140 
141 std::unique_ptr<OpPassBase<FuncOp>>
142 mlir::quant::createConvertSimulatedQuantPass() {
143   return std::make_unique<ConvertSimulatedQuantPass>();
144 }
145 
146 static PassRegistration<ConvertSimulatedQuantPass>
147     pass("quant-convert-simulated-quantization",
148          "Converts training-time simulated quantization ops to corresponding "
149          "quantize/dequantize casts.");
150