1363dd3f3SRob Suderman //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
2363dd3f3SRob Suderman //
3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6363dd3f3SRob Suderman //
7363dd3f3SRob Suderman //===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
14363dd3f3SRob Suderman #include "mlir/IR/Attributes.h"
15363dd3f3SRob Suderman #include "mlir/IR/Matchers.h"
16363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h"
17363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h"
18363dd3f3SRob Suderman #include "mlir/Pass/Pass.h"
19363dd3f3SRob Suderman 
20363dd3f3SRob Suderman using namespace mlir;
21363dd3f3SRob Suderman using namespace mlir::quant;
22363dd3f3SRob Suderman 
23363dd3f3SRob Suderman namespace {
24*80aca1eaSRiver Riddle struct ConvertConstPass : public PassWrapper<ConvertConstPass, FunctionPass> {
259a277af2SRiver Riddle /// Include the generated pass utilities.
269a277af2SRiver Riddle #define GEN_PASS_QuantConvertConst
279a277af2SRiver Riddle #include "mlir/Dialect/Quant/Passes.h.inc"
28363dd3f3SRob Suderman 
29363dd3f3SRob Suderman   void runOnFunction() override;
30363dd3f3SRob Suderman };
31363dd3f3SRob Suderman 
32363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
33363dd3f3SRob Suderman   using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
34363dd3f3SRob Suderman 
353145427dSRiver Riddle   LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
36363dd3f3SRob Suderman                                 PatternRewriter &rewriter) const override;
37363dd3f3SRob Suderman };
38363dd3f3SRob Suderman 
39363dd3f3SRob Suderman } // end anonymous namespace
40363dd3f3SRob Suderman 
41363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
42363dd3f3SRob Suderman /// quantized and the operand type is quantizable.
43363dd3f3SRob Suderman 
443145427dSRiver Riddle LogicalResult
45363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
46363dd3f3SRob Suderman                                        PatternRewriter &rewriter) const {
47363dd3f3SRob Suderman   Attribute value;
48363dd3f3SRob Suderman 
49363dd3f3SRob Suderman   // Is the operand a constant?
50363dd3f3SRob Suderman   if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
513145427dSRiver Riddle     return failure();
52363dd3f3SRob Suderman   }
53363dd3f3SRob Suderman 
54363dd3f3SRob Suderman   // Does the qbarrier convert to a quantized type. This will not be true
55363dd3f3SRob Suderman   // if a quantized type has not yet been chosen or if the cast to an equivalent
56363dd3f3SRob Suderman   // storage type is not supported.
57363dd3f3SRob Suderman   Type qbarrierResultType = qbarrier.getResult().getType();
58363dd3f3SRob Suderman   QuantizedType quantizedElementType =
59363dd3f3SRob Suderman       QuantizedType::getQuantizedElementType(qbarrierResultType);
60363dd3f3SRob Suderman   if (!quantizedElementType) {
613145427dSRiver Riddle     return failure();
62363dd3f3SRob Suderman   }
63363dd3f3SRob Suderman   if (!QuantizedType::castToStorageType(qbarrierResultType)) {
643145427dSRiver Riddle     return failure();
65363dd3f3SRob Suderman   }
66363dd3f3SRob Suderman 
67363dd3f3SRob Suderman   // Is the operand type compatible with the expressed type of the quantized
68363dd3f3SRob Suderman   // type? This will not be true if the qbarrier is superfluous (converts
69363dd3f3SRob Suderman   // from and to a quantized type).
70363dd3f3SRob Suderman   if (!quantizedElementType.isCompatibleExpressedType(
71363dd3f3SRob Suderman           qbarrier.arg().getType())) {
723145427dSRiver Riddle     return failure();
73363dd3f3SRob Suderman   }
74363dd3f3SRob Suderman 
75363dd3f3SRob Suderman   // Is the constant value a type expressed in a way that we support?
76363dd3f3SRob Suderman   if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
77363dd3f3SRob Suderman       !value.isa<SparseElementsAttr>()) {
783145427dSRiver Riddle     return failure();
79363dd3f3SRob Suderman   }
80363dd3f3SRob Suderman 
81363dd3f3SRob Suderman   Type newConstValueType;
82363dd3f3SRob Suderman   auto newConstValue =
83363dd3f3SRob Suderman       quantizeAttr(value, quantizedElementType, newConstValueType);
84363dd3f3SRob Suderman   if (!newConstValue) {
853145427dSRiver Riddle     return failure();
86363dd3f3SRob Suderman   }
87363dd3f3SRob Suderman 
88363dd3f3SRob Suderman   // When creating the new const op, use a fused location that combines the
89363dd3f3SRob Suderman   // original const and the qbarrier that led to the quantization.
90363dd3f3SRob Suderman   auto fusedLoc = FusedLoc::get(
91363dd3f3SRob Suderman       {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()},
92363dd3f3SRob Suderman       rewriter.getContext());
93363dd3f3SRob Suderman   auto newConstOp =
94363dd3f3SRob Suderman       rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
95363dd3f3SRob Suderman   rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
96363dd3f3SRob Suderman                                              newConstOp);
973145427dSRiver Riddle   return success();
98363dd3f3SRob Suderman }
99363dd3f3SRob Suderman 
100363dd3f3SRob Suderman void ConvertConstPass::runOnFunction() {
101363dd3f3SRob Suderman   OwningRewritePatternList patterns;
102363dd3f3SRob Suderman   auto func = getFunction();
103363dd3f3SRob Suderman   auto *context = &getContext();
104363dd3f3SRob Suderman   patterns.insert<QuantizedConstRewrite>(context);
105363dd3f3SRob Suderman   applyPatternsGreedily(func, patterns);
106363dd3f3SRob Suderman }
107363dd3f3SRob Suderman 
108*80aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() {
109363dd3f3SRob Suderman   return std::make_unique<ConvertConstPass>();
110363dd3f3SRob Suderman }
111