1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
13363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
14363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h"
15363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h"
16363dd3f3SRob Suderman #include "mlir/Pass/Pass.h"
17363dd3f3SRob Suderman 
18363dd3f3SRob Suderman using namespace mlir;
19363dd3f3SRob Suderman using namespace mlir::quant;
20363dd3f3SRob Suderman 
21363dd3f3SRob Suderman namespace {
229a277af2SRiver Riddle struct ConvertSimulatedQuantPass
23*80aca1eaSRiver Riddle     : public PassWrapper<ConvertSimulatedQuantPass, FunctionPass> {
249a277af2SRiver Riddle /// Include the generated pass utilities.
259a277af2SRiver Riddle #define GEN_PASS_QuantConvertSimulatedQuant
269a277af2SRiver Riddle #include "mlir/Dialect/Quant/Passes.h.inc"
279a277af2SRiver Riddle 
28363dd3f3SRob Suderman   void runOnFunction() override;
29363dd3f3SRob Suderman };
30363dd3f3SRob Suderman 
31363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
32363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp>
33363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
34363dd3f3SRob Suderman public:
35363dd3f3SRob Suderman   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
36363dd3f3SRob Suderman 
37363dd3f3SRob Suderman   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
38363dd3f3SRob Suderman       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
39363dd3f3SRob Suderman 
403145427dSRiver Riddle   LogicalResult matchAndRewrite(FakeQuantOp op,
41363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override {
42363dd3f3SRob Suderman     // TODO: If this pattern comes up more frequently, consider adding core
43363dd3f3SRob Suderman     // support for failable rewrites.
44363dd3f3SRob Suderman     if (failableRewrite(op, rewriter)) {
45363dd3f3SRob Suderman       *hadFailure = true;
463145427dSRiver Riddle       return failure();
47363dd3f3SRob Suderman     }
48363dd3f3SRob Suderman 
493145427dSRiver Riddle     return success();
50363dd3f3SRob Suderman   }
51363dd3f3SRob Suderman 
52363dd3f3SRob Suderman private:
53363dd3f3SRob Suderman   bool *hadFailure;
54363dd3f3SRob Suderman 
55363dd3f3SRob Suderman   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
56363dd3f3SRob Suderman     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
57363dd3f3SRob Suderman     if (!converter) {
58363dd3f3SRob Suderman       return (op.emitError("unsupported quantized type conversion"), true);
59363dd3f3SRob Suderman     }
60363dd3f3SRob Suderman 
61363dd3f3SRob Suderman     QuantizedType elementType =
62363dd3f3SRob Suderman         static_cast<const ConcreteRewriteClass *>(this)
63363dd3f3SRob Suderman             ->convertFakeQuantAttrsToType(op, converter.expressedType);
64363dd3f3SRob Suderman 
65363dd3f3SRob Suderman     if (!elementType) {
66363dd3f3SRob Suderman       // Note that the fakeQuantAttrsToType will have emitted the error.
67363dd3f3SRob Suderman       return true;
68363dd3f3SRob Suderman     }
69363dd3f3SRob Suderman 
70363dd3f3SRob Suderman     Type quantizedType = converter.convert(elementType);
71363dd3f3SRob Suderman     assert(quantizedType &&
72363dd3f3SRob Suderman            "Converter accepted a type that it did not convert");
73363dd3f3SRob Suderman 
74363dd3f3SRob Suderman     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
75363dd3f3SRob Suderman     // this is a forced/hard-coded constraint.
76363dd3f3SRob Suderman     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
77363dd3f3SRob Suderman                                                     op.inputs());
78363dd3f3SRob Suderman     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
79363dd3f3SRob Suderman                                                   qbarrier.getResult());
80363dd3f3SRob Suderman 
81363dd3f3SRob Suderman     return false;
82363dd3f3SRob Suderman   }
83363dd3f3SRob Suderman };
84363dd3f3SRob Suderman 
85363dd3f3SRob Suderman class ConstFakeQuantRewrite
86363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
87363dd3f3SRob Suderman public:
88363dd3f3SRob Suderman   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
89363dd3f3SRob Suderman 
90363dd3f3SRob Suderman   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
91363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
92363dd3f3SRob Suderman 
93363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
94363dd3f3SRob Suderman                                             Type expressedType) const {
95363dd3f3SRob Suderman     return fakeQuantAttrsToType(
96363dd3f3SRob Suderman         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
97363dd3f3SRob Suderman         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
98363dd3f3SRob Suderman         fqOp.narrow_range(), expressedType, fqOp.is_signed());
99363dd3f3SRob Suderman   }
100363dd3f3SRob Suderman };
101363dd3f3SRob Suderman 
102363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite
103363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
104363dd3f3SRob Suderman                               ConstFakeQuantPerAxis> {
105363dd3f3SRob Suderman public:
106363dd3f3SRob Suderman   using BaseRewrite =
107363dd3f3SRob Suderman       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
108363dd3f3SRob Suderman 
109363dd3f3SRob Suderman   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
110363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
111363dd3f3SRob Suderman 
112363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
113363dd3f3SRob Suderman                                             Type expressedType) const {
114363dd3f3SRob Suderman     SmallVector<double, 4> min, max;
115363dd3f3SRob Suderman     min.reserve(fqOp.min().size());
116363dd3f3SRob Suderman     max.reserve(fqOp.max().size());
117363dd3f3SRob Suderman     for (auto m : fqOp.min())
118363dd3f3SRob Suderman       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
119363dd3f3SRob Suderman     for (auto m : fqOp.max())
120363dd3f3SRob Suderman       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
121363dd3f3SRob Suderman 
122363dd3f3SRob Suderman     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
123363dd3f3SRob Suderman                                 fqOp.axis().getSExtValue(), min, max,
124363dd3f3SRob Suderman                                 fqOp.narrow_range(), expressedType,
125363dd3f3SRob Suderman                                 fqOp.is_signed());
126363dd3f3SRob Suderman   }
127363dd3f3SRob Suderman };
128363dd3f3SRob Suderman 
129363dd3f3SRob Suderman } // namespace
130363dd3f3SRob Suderman 
131363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() {
132363dd3f3SRob Suderman   bool hadFailure = false;
133363dd3f3SRob Suderman   OwningRewritePatternList patterns;
134363dd3f3SRob Suderman   auto func = getFunction();
135363dd3f3SRob Suderman   auto ctx = func.getContext();
136363dd3f3SRob Suderman   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
137363dd3f3SRob Suderman       ctx, &hadFailure);
138363dd3f3SRob Suderman   applyPatternsGreedily(func, patterns);
139363dd3f3SRob Suderman   if (hadFailure)
140363dd3f3SRob Suderman     signalPassFailure();
141363dd3f3SRob Suderman }
142363dd3f3SRob Suderman 
143*80aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>>
144363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() {
145363dd3f3SRob Suderman   return std::make_unique<ConvertSimulatedQuantPass>();
146363dd3f3SRob Suderman }
147