1363dd3f3SRob Suderman //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
91834ad4aSRiver Riddle #include "PassDetail.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/FakeQuantSupport.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
14363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
15363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h"
16363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h"
17363dd3f3SRob Suderman 
18363dd3f3SRob Suderman using namespace mlir;
19363dd3f3SRob Suderman using namespace mlir::quant;
20363dd3f3SRob Suderman 
21363dd3f3SRob Suderman namespace {
229a277af2SRiver Riddle struct ConvertSimulatedQuantPass
231834ad4aSRiver Riddle     : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
24363dd3f3SRob Suderman   void runOnFunction() override;
25363dd3f3SRob Suderman };
26363dd3f3SRob Suderman 
27363dd3f3SRob Suderman /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
28363dd3f3SRob Suderman template <typename ConcreteRewriteClass, typename FakeQuantOp>
29363dd3f3SRob Suderman class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
30363dd3f3SRob Suderman public:
31363dd3f3SRob Suderman   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
32363dd3f3SRob Suderman 
33363dd3f3SRob Suderman   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
34363dd3f3SRob Suderman       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
35363dd3f3SRob Suderman 
363145427dSRiver Riddle   LogicalResult matchAndRewrite(FakeQuantOp op,
37363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override {
38363dd3f3SRob Suderman     // TODO: If this pattern comes up more frequently, consider adding core
39363dd3f3SRob Suderman     // support for failable rewrites.
40363dd3f3SRob Suderman     if (failableRewrite(op, rewriter)) {
41363dd3f3SRob Suderman       *hadFailure = true;
423145427dSRiver Riddle       return failure();
43363dd3f3SRob Suderman     }
44363dd3f3SRob Suderman 
453145427dSRiver Riddle     return success();
46363dd3f3SRob Suderman   }
47363dd3f3SRob Suderman 
48363dd3f3SRob Suderman private:
49363dd3f3SRob Suderman   bool *hadFailure;
50363dd3f3SRob Suderman 
51363dd3f3SRob Suderman   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
52363dd3f3SRob Suderman     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
53363dd3f3SRob Suderman     if (!converter) {
54363dd3f3SRob Suderman       return (op.emitError("unsupported quantized type conversion"), true);
55363dd3f3SRob Suderman     }
56363dd3f3SRob Suderman 
57363dd3f3SRob Suderman     QuantizedType elementType =
58363dd3f3SRob Suderman         static_cast<const ConcreteRewriteClass *>(this)
59363dd3f3SRob Suderman             ->convertFakeQuantAttrsToType(op, converter.expressedType);
60363dd3f3SRob Suderman 
61363dd3f3SRob Suderman     if (!elementType) {
62363dd3f3SRob Suderman       // Note that the fakeQuantAttrsToType will have emitted the error.
63363dd3f3SRob Suderman       return true;
64363dd3f3SRob Suderman     }
65363dd3f3SRob Suderman 
66363dd3f3SRob Suderman     Type quantizedType = converter.convert(elementType);
67363dd3f3SRob Suderman     assert(quantizedType &&
68363dd3f3SRob Suderman            "Converter accepted a type that it did not convert");
69363dd3f3SRob Suderman 
70363dd3f3SRob Suderman     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
71363dd3f3SRob Suderman     // this is a forced/hard-coded constraint.
72363dd3f3SRob Suderman     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
73363dd3f3SRob Suderman                                                     op.inputs());
74363dd3f3SRob Suderman     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
75363dd3f3SRob Suderman                                                   qbarrier.getResult());
76363dd3f3SRob Suderman 
77363dd3f3SRob Suderman     return false;
78363dd3f3SRob Suderman   }
79363dd3f3SRob Suderman };
80363dd3f3SRob Suderman 
81363dd3f3SRob Suderman class ConstFakeQuantRewrite
82363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
83363dd3f3SRob Suderman public:
84363dd3f3SRob Suderman   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
85363dd3f3SRob Suderman 
86363dd3f3SRob Suderman   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
87363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
88363dd3f3SRob Suderman 
89363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
90363dd3f3SRob Suderman                                             Type expressedType) const {
91363dd3f3SRob Suderman     return fakeQuantAttrsToType(
92*431bb8b3SRiver Riddle         fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(),
93*431bb8b3SRiver Riddle         fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType,
94*431bb8b3SRiver Riddle         fqOp.is_signed());
95363dd3f3SRob Suderman   }
96363dd3f3SRob Suderman };
97363dd3f3SRob Suderman 
98363dd3f3SRob Suderman class ConstFakeQuantPerAxisRewrite
99363dd3f3SRob Suderman     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
100363dd3f3SRob Suderman                               ConstFakeQuantPerAxis> {
101363dd3f3SRob Suderman public:
102363dd3f3SRob Suderman   using BaseRewrite =
103363dd3f3SRob Suderman       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
104363dd3f3SRob Suderman 
105363dd3f3SRob Suderman   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
106363dd3f3SRob Suderman       : BaseRewrite(ctx, hadFailure) {}
107363dd3f3SRob Suderman 
108363dd3f3SRob Suderman   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
109363dd3f3SRob Suderman                                             Type expressedType) const {
110363dd3f3SRob Suderman     SmallVector<double, 4> min, max;
111363dd3f3SRob Suderman     min.reserve(fqOp.min().size());
112363dd3f3SRob Suderman     max.reserve(fqOp.max().size());
113363dd3f3SRob Suderman     for (auto m : fqOp.min())
114363dd3f3SRob Suderman       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
115363dd3f3SRob Suderman     for (auto m : fqOp.max())
116363dd3f3SRob Suderman       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
117363dd3f3SRob Suderman 
118*431bb8b3SRiver Riddle     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(),
119*431bb8b3SRiver Riddle                                 min, max, fqOp.narrow_range(), expressedType,
120363dd3f3SRob Suderman                                 fqOp.is_signed());
121363dd3f3SRob Suderman   }
122363dd3f3SRob Suderman };
123363dd3f3SRob Suderman 
124363dd3f3SRob Suderman } // namespace
125363dd3f3SRob Suderman 
126363dd3f3SRob Suderman void ConvertSimulatedQuantPass::runOnFunction() {
127363dd3f3SRob Suderman   bool hadFailure = false;
128363dd3f3SRob Suderman   OwningRewritePatternList patterns;
129363dd3f3SRob Suderman   auto func = getFunction();
130363dd3f3SRob Suderman   auto ctx = func.getContext();
131363dd3f3SRob Suderman   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
132363dd3f3SRob Suderman       ctx, &hadFailure);
133a5b9316bSUday Bondhugula   applyPatternsAndFoldGreedily(func, patterns);
134363dd3f3SRob Suderman   if (hadFailure)
135363dd3f3SRob Suderman     signalPassFailure();
136363dd3f3SRob Suderman }
137363dd3f3SRob Suderman 
13880aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>>
139363dd3f3SRob Suderman mlir::quant::createConvertSimulatedQuantPass() {
140363dd3f3SRob Suderman   return std::make_unique<ConvertSimulatedQuantPass>();
141363dd3f3SRob Suderman }
142