1363dd3f3SRob Suderman //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
91834ad4aSRiver Riddle #include "PassDetail.h"
10*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
14363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
15363dd3f3SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17363dd3f3SRob Suderman #include "mlir/IR/Matchers.h"
18b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19363dd3f3SRob Suderman 
20363dd3f3SRob Suderman using namespace mlir;
21363dd3f3SRob Suderman using namespace mlir::quant;
22363dd3f3SRob Suderman 
23363dd3f3SRob Suderman namespace {
241834ad4aSRiver Riddle struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> {
25363dd3f3SRob Suderman   void runOnFunction() override;
26363dd3f3SRob Suderman };
27363dd3f3SRob Suderman 
28363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
29363dd3f3SRob Suderman   using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
30363dd3f3SRob Suderman 
313145427dSRiver Riddle   LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
32363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override;
33363dd3f3SRob Suderman };
34363dd3f3SRob Suderman 
35363dd3f3SRob Suderman } // end anonymous namespace
36363dd3f3SRob Suderman 
37363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
38363dd3f3SRob Suderman /// quantized and the operand type is quantizable.
39363dd3f3SRob Suderman 
403145427dSRiver Riddle LogicalResult
41363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
42363dd3f3SRob Suderman                                        PatternRewriter &rewriter) const {
43363dd3f3SRob Suderman   Attribute value;
44363dd3f3SRob Suderman 
45363dd3f3SRob Suderman   // Is the operand a constant?
46363dd3f3SRob Suderman   if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
473145427dSRiver Riddle     return failure();
48363dd3f3SRob Suderman   }
49363dd3f3SRob Suderman 
50363dd3f3SRob Suderman   // Does the qbarrier convert to a quantized type. This will not be true
51363dd3f3SRob Suderman   // if a quantized type has not yet been chosen or if the cast to an equivalent
52363dd3f3SRob Suderman   // storage type is not supported.
53363dd3f3SRob Suderman   Type qbarrierResultType = qbarrier.getResult().getType();
54363dd3f3SRob Suderman   QuantizedType quantizedElementType =
55363dd3f3SRob Suderman       QuantizedType::getQuantizedElementType(qbarrierResultType);
56363dd3f3SRob Suderman   if (!quantizedElementType) {
573145427dSRiver Riddle     return failure();
58363dd3f3SRob Suderman   }
59363dd3f3SRob Suderman   if (!QuantizedType::castToStorageType(qbarrierResultType)) {
603145427dSRiver Riddle     return failure();
61363dd3f3SRob Suderman   }
62363dd3f3SRob Suderman 
63363dd3f3SRob Suderman   // Is the operand type compatible with the expressed type of the quantized
64363dd3f3SRob Suderman   // type? This will not be true if the qbarrier is superfluous (converts
65363dd3f3SRob Suderman   // from and to a quantized type).
66363dd3f3SRob Suderman   if (!quantizedElementType.isCompatibleExpressedType(
67363dd3f3SRob Suderman           qbarrier.arg().getType())) {
683145427dSRiver Riddle     return failure();
69363dd3f3SRob Suderman   }
70363dd3f3SRob Suderman 
71363dd3f3SRob Suderman   // Is the constant value a type expressed in a way that we support?
72ee394e68SRahul Joshi   if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
733145427dSRiver Riddle     return failure();
74363dd3f3SRob Suderman   }
75363dd3f3SRob Suderman 
76363dd3f3SRob Suderman   Type newConstValueType;
77363dd3f3SRob Suderman   auto newConstValue =
78363dd3f3SRob Suderman       quantizeAttr(value, quantizedElementType, newConstValueType);
79363dd3f3SRob Suderman   if (!newConstValue) {
803145427dSRiver Riddle     return failure();
81363dd3f3SRob Suderman   }
82363dd3f3SRob Suderman 
83363dd3f3SRob Suderman   // When creating the new const op, use a fused location that combines the
84363dd3f3SRob Suderman   // original const and the qbarrier that led to the quantization.
85a4bb667dSRiver Riddle   auto fusedLoc = rewriter.getFusedLoc(
86a4bb667dSRiver Riddle       {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()});
87*a54f4eaeSMogball   auto newConstOp = rewriter.create<arith::ConstantOp>(
88*a54f4eaeSMogball       fusedLoc, newConstValueType, newConstValue);
89363dd3f3SRob Suderman   rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
90363dd3f3SRob Suderman                                              newConstOp);
913145427dSRiver Riddle   return success();
92363dd3f3SRob Suderman }
93363dd3f3SRob Suderman 
94363dd3f3SRob Suderman void ConvertConstPass::runOnFunction() {
95dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
96363dd3f3SRob Suderman   auto func = getFunction();
97363dd3f3SRob Suderman   auto *context = &getContext();
98dc4e913bSChris Lattner   patterns.add<QuantizedConstRewrite>(context);
99e21adfa3SRiver Riddle   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
100363dd3f3SRob Suderman }
101363dd3f3SRob Suderman 
10280aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() {
103363dd3f3SRob Suderman   return std::make_unique<ConvertConstPass>();
104363dd3f3SRob Suderman }
105