1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/FakeQuantSupport.h"
10 #include "mlir/Dialect/Quant/Passes.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 #include "mlir/Dialect/Quant/UniformSupport.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::quant;
20 
21 namespace {
22 struct ConvertSimulatedQuantPass
23     : public FunctionPass<ConvertSimulatedQuantPass> {
24 /// Include the generated pass utilities.
25 #define GEN_PASS_QuantConvertSimulatedQuant
26 #include "mlir/Dialect/Quant/Passes.h.inc"
27 
28   void runOnFunction() override;
29 };
30 
31 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
32 template <typename ConcreteRewriteClass, typename FakeQuantOp>
33 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
34 public:
35   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
36 
37   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
38       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
39 
40   LogicalResult matchAndRewrite(FakeQuantOp op,
41                                 PatternRewriter &rewriter) const override {
42     // TODO: If this pattern comes up more frequently, consider adding core
43     // support for failable rewrites.
44     if (failableRewrite(op, rewriter)) {
45       *hadFailure = true;
46       return failure();
47     }
48 
49     return success();
50   }
51 
52 private:
53   bool *hadFailure;
54 
55   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
56     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
57     if (!converter) {
58       return (op.emitError("unsupported quantized type conversion"), true);
59     }
60 
61     QuantizedType elementType =
62         static_cast<const ConcreteRewriteClass *>(this)
63             ->convertFakeQuantAttrsToType(op, converter.expressedType);
64 
65     if (!elementType) {
66       // Note that the fakeQuantAttrsToType will have emitted the error.
67       return true;
68     }
69 
70     Type quantizedType = converter.convert(elementType);
71     assert(quantizedType &&
72            "Converter accepted a type that it did not convert");
73 
74     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
75     // this is a forced/hard-coded constraint.
76     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
77                                                     op.inputs());
78     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
79                                                   qbarrier.getResult());
80 
81     return false;
82   }
83 };
84 
85 class ConstFakeQuantRewrite
86     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
87 public:
88   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
89 
90   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
91       : BaseRewrite(ctx, hadFailure) {}
92 
93   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
94                                             Type expressedType) const {
95     return fakeQuantAttrsToType(
96         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
97         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
98         fqOp.narrow_range(), expressedType, fqOp.is_signed());
99   }
100 };
101 
102 class ConstFakeQuantPerAxisRewrite
103     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
104                               ConstFakeQuantPerAxis> {
105 public:
106   using BaseRewrite =
107       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
108 
109   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
110       : BaseRewrite(ctx, hadFailure) {}
111 
112   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
113                                             Type expressedType) const {
114     SmallVector<double, 4> min, max;
115     min.reserve(fqOp.min().size());
116     max.reserve(fqOp.max().size());
117     for (auto m : fqOp.min())
118       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
119     for (auto m : fqOp.max())
120       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
121 
122     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
123                                 fqOp.axis().getSExtValue(), min, max,
124                                 fqOp.narrow_range(), expressedType,
125                                 fqOp.is_signed());
126   }
127 };
128 
129 } // namespace
130 
131 void ConvertSimulatedQuantPass::runOnFunction() {
132   bool hadFailure = false;
133   OwningRewritePatternList patterns;
134   auto func = getFunction();
135   auto ctx = func.getContext();
136   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
137       ctx, &hadFailure);
138   applyPatternsGreedily(func, patterns);
139   if (hadFailure)
140     signalPassFailure();
141 }
142 
143 std::unique_ptr<OpPassBase<FuncOp>>
144 mlir::quant::createConvertSimulatedQuantPass() {
145   return std::make_unique<ConvertSimulatedQuantPass>();
146 }
147