1363dd3f3SRob Suderman //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 91834ad4aSRiver Riddle #include "PassDetail.h" 10*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 13363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h" 14363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 15363dd3f3SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17363dd3f3SRob Suderman #include "mlir/IR/Matchers.h" 18b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19363dd3f3SRob Suderman 20363dd3f3SRob Suderman using namespace mlir; 21363dd3f3SRob Suderman using namespace mlir::quant; 22363dd3f3SRob Suderman 23363dd3f3SRob Suderman namespace { 241834ad4aSRiver Riddle struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> { 25363dd3f3SRob Suderman void runOnFunction() override; 26363dd3f3SRob Suderman }; 27363dd3f3SRob Suderman 28363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 29363dd3f3SRob Suderman using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 30363dd3f3SRob Suderman 313145427dSRiver Riddle LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, 32363dd3f3SRob Suderman PatternRewriter &rewriter) const override; 33363dd3f3SRob Suderman }; 34363dd3f3SRob Suderman 35363dd3f3SRob Suderman } // end anonymous namespace 36363dd3f3SRob Suderman 37363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 38363dd3f3SRob Suderman /// quantized and the operand type is quantizable. 39363dd3f3SRob Suderman 403145427dSRiver Riddle LogicalResult 41363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 42363dd3f3SRob Suderman PatternRewriter &rewriter) const { 43363dd3f3SRob Suderman Attribute value; 44363dd3f3SRob Suderman 45363dd3f3SRob Suderman // Is the operand a constant? 46363dd3f3SRob Suderman if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 473145427dSRiver Riddle return failure(); 48363dd3f3SRob Suderman } 49363dd3f3SRob Suderman 50363dd3f3SRob Suderman // Does the qbarrier convert to a quantized type. This will not be true 51363dd3f3SRob Suderman // if a quantized type has not yet been chosen or if the cast to an equivalent 52363dd3f3SRob Suderman // storage type is not supported. 53363dd3f3SRob Suderman Type qbarrierResultType = qbarrier.getResult().getType(); 54363dd3f3SRob Suderman QuantizedType quantizedElementType = 55363dd3f3SRob Suderman QuantizedType::getQuantizedElementType(qbarrierResultType); 56363dd3f3SRob Suderman if (!quantizedElementType) { 573145427dSRiver Riddle return failure(); 58363dd3f3SRob Suderman } 59363dd3f3SRob Suderman if (!QuantizedType::castToStorageType(qbarrierResultType)) { 603145427dSRiver Riddle return failure(); 61363dd3f3SRob Suderman } 62363dd3f3SRob Suderman 63363dd3f3SRob Suderman // Is the operand type compatible with the expressed type of the quantized 64363dd3f3SRob Suderman // type? This will not be true if the qbarrier is superfluous (converts 65363dd3f3SRob Suderman // from and to a quantized type). 66363dd3f3SRob Suderman if (!quantizedElementType.isCompatibleExpressedType( 67363dd3f3SRob Suderman qbarrier.arg().getType())) { 683145427dSRiver Riddle return failure(); 69363dd3f3SRob Suderman } 70363dd3f3SRob Suderman 71363dd3f3SRob Suderman // Is the constant value a type expressed in a way that we support? 72ee394e68SRahul Joshi if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) { 733145427dSRiver Riddle return failure(); 74363dd3f3SRob Suderman } 75363dd3f3SRob Suderman 76363dd3f3SRob Suderman Type newConstValueType; 77363dd3f3SRob Suderman auto newConstValue = 78363dd3f3SRob Suderman quantizeAttr(value, quantizedElementType, newConstValueType); 79363dd3f3SRob Suderman if (!newConstValue) { 803145427dSRiver Riddle return failure(); 81363dd3f3SRob Suderman } 82363dd3f3SRob Suderman 83363dd3f3SRob Suderman // When creating the new const op, use a fused location that combines the 84363dd3f3SRob Suderman // original const and the qbarrier that led to the quantization. 85a4bb667dSRiver Riddle auto fusedLoc = rewriter.getFusedLoc( 86a4bb667dSRiver Riddle {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}); 87*a54f4eaeSMogball auto newConstOp = rewriter.create<arith::ConstantOp>( 88*a54f4eaeSMogball fusedLoc, newConstValueType, newConstValue); 89363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(), 90363dd3f3SRob Suderman newConstOp); 913145427dSRiver Riddle return success(); 92363dd3f3SRob Suderman } 93363dd3f3SRob Suderman 94363dd3f3SRob Suderman void ConvertConstPass::runOnFunction() { 95dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 96363dd3f3SRob Suderman auto func = getFunction(); 97363dd3f3SRob Suderman auto *context = &getContext(); 98dc4e913bSChris Lattner patterns.add<QuantizedConstRewrite>(context); 99e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 100363dd3f3SRob Suderman } 101363dd3f3SRob Suderman 10280aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() { 103363dd3f3SRob Suderman return std::make_unique<ConvertConstPass>(); 104363dd3f3SRob Suderman } 105