1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Quant/FakeQuantSupport.h"
11 #include "mlir/Dialect/Quant/Passes.h"
12 #include "mlir/Dialect/Quant/QuantOps.h"
13 #include "mlir/Dialect/Quant/UniformSupport.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 
18 using namespace mlir;
19 using namespace mlir::quant;
20 
21 namespace {
22 struct ConvertSimulatedQuantPass
23     : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
24   void runOnFunction() override;
25 };
26 
27 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
28 template <typename ConcreteRewriteClass, typename FakeQuantOp>
29 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
30 public:
31   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
32 
33   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
34       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
35 
36   LogicalResult matchAndRewrite(FakeQuantOp op,
37                                 PatternRewriter &rewriter) const override {
38     // TODO: If this pattern comes up more frequently, consider adding core
39     // support for failable rewrites.
40     if (failableRewrite(op, rewriter)) {
41       *hadFailure = true;
42       return failure();
43     }
44 
45     return success();
46   }
47 
48 private:
49   bool *hadFailure;
50 
51   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
52     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
53     if (!converter) {
54       return (op.emitError("unsupported quantized type conversion"), true);
55     }
56 
57     QuantizedType elementType =
58         static_cast<const ConcreteRewriteClass *>(this)
59             ->convertFakeQuantAttrsToType(op, converter.expressedType);
60 
61     if (!elementType) {
62       // Note that the fakeQuantAttrsToType will have emitted the error.
63       return true;
64     }
65 
66     Type quantizedType = converter.convert(elementType);
67     assert(quantizedType &&
68            "Converter accepted a type that it did not convert");
69 
70     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
71     // this is a forced/hard-coded constraint.
72     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
73                                                     op.inputs());
74     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
75                                                   qbarrier.getResult());
76 
77     return false;
78   }
79 };
80 
81 class ConstFakeQuantRewrite
82     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
83 public:
84   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
85 
86   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
87       : BaseRewrite(ctx, hadFailure) {}
88 
89   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
90                                             Type expressedType) const {
91     return fakeQuantAttrsToType(
92         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
93         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
94         fqOp.narrow_range(), expressedType, fqOp.is_signed());
95   }
96 };
97 
98 class ConstFakeQuantPerAxisRewrite
99     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
100                               ConstFakeQuantPerAxis> {
101 public:
102   using BaseRewrite =
103       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
104 
105   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
106       : BaseRewrite(ctx, hadFailure) {}
107 
108   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
109                                             Type expressedType) const {
110     SmallVector<double, 4> min, max;
111     min.reserve(fqOp.min().size());
112     max.reserve(fqOp.max().size());
113     for (auto m : fqOp.min())
114       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
115     for (auto m : fqOp.max())
116       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
117 
118     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
119                                 fqOp.axis().getSExtValue(), min, max,
120                                 fqOp.narrow_range(), expressedType,
121                                 fqOp.is_signed());
122   }
123 };
124 
125 } // namespace
126 
127 void ConvertSimulatedQuantPass::runOnFunction() {
128   bool hadFailure = false;
129   OwningRewritePatternList patterns;
130   auto func = getFunction();
131   auto ctx = func.getContext();
132   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
133       ctx, &hadFailure);
134   applyPatternsAndFoldGreedily(func, patterns);
135   if (hadFailure)
136     signalPassFailure();
137 }
138 
139 std::unique_ptr<OperationPass<FuncOp>>
140 mlir::quant::createConvertSimulatedQuantPass() {
141   return std::make_unique<ConvertSimulatedQuantPass>();
142 }
143