1363dd3f3SRob Suderman //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 9363dd3f3SRob Suderman #include "mlir/Dialect/Quant/Passes.h" 10363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantOps.h" 11363dd3f3SRob Suderman #include "mlir/Dialect/Quant/QuantizeUtils.h" 12363dd3f3SRob Suderman #include "mlir/Dialect/Quant/UniformSupport.h" 13363dd3f3SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 14363dd3f3SRob Suderman #include "mlir/IR/Attributes.h" 15363dd3f3SRob Suderman #include "mlir/IR/Matchers.h" 16363dd3f3SRob Suderman #include "mlir/IR/PatternMatch.h" 17363dd3f3SRob Suderman #include "mlir/IR/StandardTypes.h" 18363dd3f3SRob Suderman #include "mlir/Pass/Pass.h" 19363dd3f3SRob Suderman 20363dd3f3SRob Suderman using namespace mlir; 21363dd3f3SRob Suderman using namespace mlir::quant; 22363dd3f3SRob Suderman 23363dd3f3SRob Suderman namespace { 24*80aca1eaSRiver Riddle struct ConvertConstPass : public PassWrapper<ConvertConstPass, FunctionPass> { 259a277af2SRiver Riddle /// Include the generated pass utilities. 269a277af2SRiver Riddle #define GEN_PASS_QuantConvertConst 279a277af2SRiver Riddle #include "mlir/Dialect/Quant/Passes.h.inc" 28363dd3f3SRob Suderman 29363dd3f3SRob Suderman void runOnFunction() override; 30363dd3f3SRob Suderman }; 31363dd3f3SRob Suderman 32363dd3f3SRob Suderman struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 33363dd3f3SRob Suderman using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 34363dd3f3SRob Suderman 353145427dSRiver Riddle LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, 36363dd3f3SRob Suderman PatternRewriter &rewriter) const override; 37363dd3f3SRob Suderman }; 38363dd3f3SRob Suderman 39363dd3f3SRob Suderman } // end anonymous namespace 40363dd3f3SRob Suderman 41363dd3f3SRob Suderman /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 42363dd3f3SRob Suderman /// quantized and the operand type is quantizable. 43363dd3f3SRob Suderman 443145427dSRiver Riddle LogicalResult 45363dd3f3SRob Suderman QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 46363dd3f3SRob Suderman PatternRewriter &rewriter) const { 47363dd3f3SRob Suderman Attribute value; 48363dd3f3SRob Suderman 49363dd3f3SRob Suderman // Is the operand a constant? 50363dd3f3SRob Suderman if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 513145427dSRiver Riddle return failure(); 52363dd3f3SRob Suderman } 53363dd3f3SRob Suderman 54363dd3f3SRob Suderman // Does the qbarrier convert to a quantized type. This will not be true 55363dd3f3SRob Suderman // if a quantized type has not yet been chosen or if the cast to an equivalent 56363dd3f3SRob Suderman // storage type is not supported. 57363dd3f3SRob Suderman Type qbarrierResultType = qbarrier.getResult().getType(); 58363dd3f3SRob Suderman QuantizedType quantizedElementType = 59363dd3f3SRob Suderman QuantizedType::getQuantizedElementType(qbarrierResultType); 60363dd3f3SRob Suderman if (!quantizedElementType) { 613145427dSRiver Riddle return failure(); 62363dd3f3SRob Suderman } 63363dd3f3SRob Suderman if (!QuantizedType::castToStorageType(qbarrierResultType)) { 643145427dSRiver Riddle return failure(); 65363dd3f3SRob Suderman } 66363dd3f3SRob Suderman 67363dd3f3SRob Suderman // Is the operand type compatible with the expressed type of the quantized 68363dd3f3SRob Suderman // type? This will not be true if the qbarrier is superfluous (converts 69363dd3f3SRob Suderman // from and to a quantized type). 70363dd3f3SRob Suderman if (!quantizedElementType.isCompatibleExpressedType( 71363dd3f3SRob Suderman qbarrier.arg().getType())) { 723145427dSRiver Riddle return failure(); 73363dd3f3SRob Suderman } 74363dd3f3SRob Suderman 75363dd3f3SRob Suderman // Is the constant value a type expressed in a way that we support? 76363dd3f3SRob Suderman if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() && 77363dd3f3SRob Suderman !value.isa<SparseElementsAttr>()) { 783145427dSRiver Riddle return failure(); 79363dd3f3SRob Suderman } 80363dd3f3SRob Suderman 81363dd3f3SRob Suderman Type newConstValueType; 82363dd3f3SRob Suderman auto newConstValue = 83363dd3f3SRob Suderman quantizeAttr(value, quantizedElementType, newConstValueType); 84363dd3f3SRob Suderman if (!newConstValue) { 853145427dSRiver Riddle return failure(); 86363dd3f3SRob Suderman } 87363dd3f3SRob Suderman 88363dd3f3SRob Suderman // When creating the new const op, use a fused location that combines the 89363dd3f3SRob Suderman // original const and the qbarrier that led to the quantization. 90363dd3f3SRob Suderman auto fusedLoc = FusedLoc::get( 91363dd3f3SRob Suderman {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}, 92363dd3f3SRob Suderman rewriter.getContext()); 93363dd3f3SRob Suderman auto newConstOp = 94363dd3f3SRob Suderman rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); 95363dd3f3SRob Suderman rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(), 96363dd3f3SRob Suderman newConstOp); 973145427dSRiver Riddle return success(); 98363dd3f3SRob Suderman } 99363dd3f3SRob Suderman 100363dd3f3SRob Suderman void ConvertConstPass::runOnFunction() { 101363dd3f3SRob Suderman OwningRewritePatternList patterns; 102363dd3f3SRob Suderman auto func = getFunction(); 103363dd3f3SRob Suderman auto *context = &getContext(); 104363dd3f3SRob Suderman patterns.insert<QuantizedConstRewrite>(context); 105363dd3f3SRob Suderman applyPatternsGreedily(func, patterns); 106363dd3f3SRob Suderman } 107363dd3f3SRob Suderman 108*80aca1eaSRiver Riddle std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() { 109363dd3f3SRob Suderman return std::make_unique<ConvertConstPass>(); 110363dd3f3SRob Suderman } 111