1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/Passes.h"
10 #include "mlir/Dialect/Quant/QuantOps.h"
11 #include "mlir/Dialect/Quant/QuantizeUtils.h"
12 #include "mlir/Dialect/Quant/UniformSupport.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 using namespace mlir::quant;
22 
namespace {
/// Function pass that greedily applies QuantizedConstRewrite to the current
/// function, turning quantize casts of constant values into new constants
/// that already hold the quantized value.
struct ConvertConstPass : public FunctionPass<ConvertConstPass> {
/// Include the generated pass utilities.
#define GEN_PASS_QuantConvertConst
#include "mlir/Dialect/Quant/Passes.h.inc"

  /// Collects the rewrite pattern(s) and applies them greedily to the function.
  void runOnFunction() override;
};

/// Rewrite pattern rooted at QuantizeCastOp ("qbarrier"). When the cast's
/// operand is a constant, it creates a new ConstantOp holding the quantized
/// attribute and replaces the cast with a StorageCastOp of that constant.
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
  using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;

  /// Attempts the rewrite on `qbarrier`. Returns failure() when the cast does
  /// not qualify, leaving the IR untouched, and success() after replacing it.
  LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
                                PatternRewriter &rewriter) const override;
};

} // end anonymous namespace
40 
41 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
42 /// quantized and the operand type is quantizable.
43 
44 LogicalResult
45 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
46                                        PatternRewriter &rewriter) const {
47   Attribute value;
48 
49   // Is the operand a constant?
50   if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
51     return failure();
52   }
53 
54   // Does the qbarrier convert to a quantized type. This will not be true
55   // if a quantized type has not yet been chosen or if the cast to an equivalent
56   // storage type is not supported.
57   Type qbarrierResultType = qbarrier.getResult().getType();
58   QuantizedType quantizedElementType =
59       QuantizedType::getQuantizedElementType(qbarrierResultType);
60   if (!quantizedElementType) {
61     return failure();
62   }
63   if (!QuantizedType::castToStorageType(qbarrierResultType)) {
64     return failure();
65   }
66 
67   // Is the operand type compatible with the expressed type of the quantized
68   // type? This will not be true if the qbarrier is superfluous (converts
69   // from and to a quantized type).
70   if (!quantizedElementType.isCompatibleExpressedType(
71           qbarrier.arg().getType())) {
72     return failure();
73   }
74 
75   // Is the constant value a type expressed in a way that we support?
76   if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
77       !value.isa<SparseElementsAttr>()) {
78     return failure();
79   }
80 
81   Type newConstValueType;
82   auto newConstValue =
83       quantizeAttr(value, quantizedElementType, newConstValueType);
84   if (!newConstValue) {
85     return failure();
86   }
87 
88   // When creating the new const op, use a fused location that combines the
89   // original const and the qbarrier that led to the quantization.
90   auto fusedLoc = FusedLoc::get(
91       {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()},
92       rewriter.getContext());
93   auto newConstOp =
94       rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
95   rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
96                                              newConstOp);
97   return success();
98 }
99 
100 void ConvertConstPass::runOnFunction() {
101   OwningRewritePatternList patterns;
102   auto func = getFunction();
103   auto *context = &getContext();
104   patterns.insert<QuantizedConstRewrite>(context);
105   applyPatternsGreedily(func, patterns);
106 }
107 
108 std::unique_ptr<OpPassBase<FuncOp>> mlir::quant::createConvertConstPass() {
109   return std::make_unique<ConvertConstPass>();
110 }
111