1363dd3f3SRob Suderman //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
91834ad4aSRiver Riddle #include "PassDetail.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
14363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
1509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
16363dd3f3SRob Suderman #include "mlir/IR/Matchers.h"
17b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18363dd3f3SRob Suderman 
19363dd3f3SRob Suderman using namespace mlir;
20363dd3f3SRob Suderman using namespace mlir::quant;
21363dd3f3SRob Suderman 
22363dd3f3SRob Suderman namespace {
231834ad4aSRiver Riddle struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> {
2441574554SRiver Riddle   void runOnOperation() override;
25363dd3f3SRob Suderman };
26363dd3f3SRob Suderman 
27363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
28363dd3f3SRob Suderman   using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
29363dd3f3SRob Suderman 
303145427dSRiver Riddle   LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
31363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override;
32363dd3f3SRob Suderman };
33363dd3f3SRob Suderman 
34be0a7e9fSMehdi Amini } // namespace
35363dd3f3SRob Suderman 
36363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
37363dd3f3SRob Suderman /// quantized and the operand type is quantizable.
38363dd3f3SRob Suderman 
393145427dSRiver Riddle LogicalResult
matchAndRewrite(QuantizeCastOp qbarrier,PatternRewriter & rewriter) const40363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
41363dd3f3SRob Suderman                                        PatternRewriter &rewriter) const {
42363dd3f3SRob Suderman   Attribute value;
43363dd3f3SRob Suderman 
44363dd3f3SRob Suderman   // Is the operand a constant?
45*04235d07SJacques Pienaar   if (!matchPattern(qbarrier.getArg(), m_Constant(&value))) {
463145427dSRiver Riddle     return failure();
47363dd3f3SRob Suderman   }
48363dd3f3SRob Suderman 
49363dd3f3SRob Suderman   // Does the qbarrier convert to a quantized type. This will not be true
50363dd3f3SRob Suderman   // if a quantized type has not yet been chosen or if the cast to an equivalent
51363dd3f3SRob Suderman   // storage type is not supported.
52363dd3f3SRob Suderman   Type qbarrierResultType = qbarrier.getResult().getType();
53363dd3f3SRob Suderman   QuantizedType quantizedElementType =
54363dd3f3SRob Suderman       QuantizedType::getQuantizedElementType(qbarrierResultType);
55363dd3f3SRob Suderman   if (!quantizedElementType) {
563145427dSRiver Riddle     return failure();
57363dd3f3SRob Suderman   }
58363dd3f3SRob Suderman   if (!QuantizedType::castToStorageType(qbarrierResultType)) {
593145427dSRiver Riddle     return failure();
60363dd3f3SRob Suderman   }
61363dd3f3SRob Suderman 
62363dd3f3SRob Suderman   // Is the operand type compatible with the expressed type of the quantized
63363dd3f3SRob Suderman   // type? This will not be true if the qbarrier is superfluous (converts
64363dd3f3SRob Suderman   // from and to a quantized type).
65363dd3f3SRob Suderman   if (!quantizedElementType.isCompatibleExpressedType(
66*04235d07SJacques Pienaar           qbarrier.getArg().getType())) {
673145427dSRiver Riddle     return failure();
68363dd3f3SRob Suderman   }
69363dd3f3SRob Suderman 
70363dd3f3SRob Suderman   // Is the constant value a type expressed in a way that we support?
71ee394e68SRahul Joshi   if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
723145427dSRiver Riddle     return failure();
73363dd3f3SRob Suderman   }
74363dd3f3SRob Suderman 
75363dd3f3SRob Suderman   Type newConstValueType;
76363dd3f3SRob Suderman   auto newConstValue =
77363dd3f3SRob Suderman       quantizeAttr(value, quantizedElementType, newConstValueType);
78363dd3f3SRob Suderman   if (!newConstValue) {
793145427dSRiver Riddle     return failure();
80363dd3f3SRob Suderman   }
81363dd3f3SRob Suderman 
82363dd3f3SRob Suderman   // When creating the new const op, use a fused location that combines the
83363dd3f3SRob Suderman   // original const and the qbarrier that led to the quantization.
84a4bb667dSRiver Riddle   auto fusedLoc = rewriter.getFusedLoc(
85*04235d07SJacques Pienaar       {qbarrier.getArg().getDefiningOp()->getLoc(), qbarrier.getLoc()});
86a54f4eaeSMogball   auto newConstOp = rewriter.create<arith::ConstantOp>(
87a54f4eaeSMogball       fusedLoc, newConstValueType, newConstValue);
88363dd3f3SRob Suderman   rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
89363dd3f3SRob Suderman                                              newConstOp);
903145427dSRiver Riddle   return success();
91363dd3f3SRob Suderman }
92363dd3f3SRob Suderman 
runOnOperation()9341574554SRiver Riddle void ConvertConstPass::runOnOperation() {
94dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
9541574554SRiver Riddle   auto func = getOperation();
96363dd3f3SRob Suderman   auto *context = &getContext();
97dc4e913bSChris Lattner   patterns.add<QuantizedConstRewrite>(context);
98e21adfa3SRiver Riddle   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
99363dd3f3SRob Suderman }
100363dd3f3SRob Suderman 
10158ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createConvertConstPass()10258ceae95SRiver Riddle mlir::quant::createConvertConstPass() {
103363dd3f3SRob Suderman   return std::make_unique<ConvertConstPass>();
104363dd3f3SRob Suderman }
105