//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman
91834ad4aSRiver Riddle #include "PassDetail.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h"
12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h"
13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h"
14363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h"
1509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
16363dd3f3SRob Suderman #include "mlir/IR/Matchers.h"
17b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18363dd3f3SRob Suderman
19363dd3f3SRob Suderman using namespace mlir;
20363dd3f3SRob Suderman using namespace mlir::quant;
21363dd3f3SRob Suderman
22363dd3f3SRob Suderman namespace {
231834ad4aSRiver Riddle struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> {
2441574554SRiver Riddle void runOnOperation() override;
25363dd3f3SRob Suderman };
26363dd3f3SRob Suderman
27363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
28363dd3f3SRob Suderman using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
29363dd3f3SRob Suderman
303145427dSRiver Riddle LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
31363dd3f3SRob Suderman PatternRewriter &rewriter) const override;
32363dd3f3SRob Suderman };
33363dd3f3SRob Suderman
34be0a7e9fSMehdi Amini } // namespace
35363dd3f3SRob Suderman
36363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
37363dd3f3SRob Suderman /// quantized and the operand type is quantizable.
38363dd3f3SRob Suderman
393145427dSRiver Riddle LogicalResult
matchAndRewrite(QuantizeCastOp qbarrier,PatternRewriter & rewriter) const40363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
41363dd3f3SRob Suderman PatternRewriter &rewriter) const {
42363dd3f3SRob Suderman Attribute value;
43363dd3f3SRob Suderman
44363dd3f3SRob Suderman // Is the operand a constant?
45*04235d07SJacques Pienaar if (!matchPattern(qbarrier.getArg(), m_Constant(&value))) {
463145427dSRiver Riddle return failure();
47363dd3f3SRob Suderman }
48363dd3f3SRob Suderman
49363dd3f3SRob Suderman // Does the qbarrier convert to a quantized type. This will not be true
50363dd3f3SRob Suderman // if a quantized type has not yet been chosen or if the cast to an equivalent
51363dd3f3SRob Suderman // storage type is not supported.
52363dd3f3SRob Suderman Type qbarrierResultType = qbarrier.getResult().getType();
53363dd3f3SRob Suderman QuantizedType quantizedElementType =
54363dd3f3SRob Suderman QuantizedType::getQuantizedElementType(qbarrierResultType);
55363dd3f3SRob Suderman if (!quantizedElementType) {
563145427dSRiver Riddle return failure();
57363dd3f3SRob Suderman }
58363dd3f3SRob Suderman if (!QuantizedType::castToStorageType(qbarrierResultType)) {
593145427dSRiver Riddle return failure();
60363dd3f3SRob Suderman }
61363dd3f3SRob Suderman
62363dd3f3SRob Suderman // Is the operand type compatible with the expressed type of the quantized
63363dd3f3SRob Suderman // type? This will not be true if the qbarrier is superfluous (converts
64363dd3f3SRob Suderman // from and to a quantized type).
65363dd3f3SRob Suderman if (!quantizedElementType.isCompatibleExpressedType(
66*04235d07SJacques Pienaar qbarrier.getArg().getType())) {
673145427dSRiver Riddle return failure();
68363dd3f3SRob Suderman }
69363dd3f3SRob Suderman
70363dd3f3SRob Suderman // Is the constant value a type expressed in a way that we support?
71ee394e68SRahul Joshi if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
723145427dSRiver Riddle return failure();
73363dd3f3SRob Suderman }
74363dd3f3SRob Suderman
75363dd3f3SRob Suderman Type newConstValueType;
76363dd3f3SRob Suderman auto newConstValue =
77363dd3f3SRob Suderman quantizeAttr(value, quantizedElementType, newConstValueType);
78363dd3f3SRob Suderman if (!newConstValue) {
793145427dSRiver Riddle return failure();
80363dd3f3SRob Suderman }
81363dd3f3SRob Suderman
82363dd3f3SRob Suderman // When creating the new const op, use a fused location that combines the
83363dd3f3SRob Suderman // original const and the qbarrier that led to the quantization.
84a4bb667dSRiver Riddle auto fusedLoc = rewriter.getFusedLoc(
85*04235d07SJacques Pienaar {qbarrier.getArg().getDefiningOp()->getLoc(), qbarrier.getLoc()});
86a54f4eaeSMogball auto newConstOp = rewriter.create<arith::ConstantOp>(
87a54f4eaeSMogball fusedLoc, newConstValueType, newConstValue);
88363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
89363dd3f3SRob Suderman newConstOp);
903145427dSRiver Riddle return success();
91363dd3f3SRob Suderman }
92363dd3f3SRob Suderman
runOnOperation()9341574554SRiver Riddle void ConvertConstPass::runOnOperation() {
94dc4e913bSChris Lattner RewritePatternSet patterns(&getContext());
9541574554SRiver Riddle auto func = getOperation();
96363dd3f3SRob Suderman auto *context = &getContext();
97dc4e913bSChris Lattner patterns.add<QuantizedConstRewrite>(context);
98e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
99363dd3f3SRob Suderman }
100363dd3f3SRob Suderman
10158ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createConvertConstPass()10258ceae95SRiver Riddle mlir::quant::createConvertConstPass() {
103363dd3f3SRob Suderman return std::make_unique<ConvertConstPass>();
104363dd3f3SRob Suderman }
105