1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Quant/Passes.h" 11 #include "mlir/Dialect/Quant/QuantOps.h" 12 #include "mlir/Dialect/Quant/QuantizeUtils.h" 13 #include "mlir/Dialect/Quant/UniformSupport.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/StandardTypes.h" 19 20 using namespace mlir; 21 using namespace mlir::quant; 22 23 namespace { 24 struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> { 25 void runOnFunction() override; 26 }; 27 28 struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 29 using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 30 31 LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, 32 PatternRewriter &rewriter) const override; 33 }; 34 35 } // end anonymous namespace 36 37 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 38 /// quantized and the operand type is quantizable. 39 40 LogicalResult 41 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 42 PatternRewriter &rewriter) const { 43 Attribute value; 44 45 // Is the operand a constant? 46 if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 47 return failure(); 48 } 49 50 // Does the qbarrier convert to a quantized type. This will not be true 51 // if a quantized type has not yet been chosen or if the cast to an equivalent 52 // storage type is not supported. 53 Type qbarrierResultType = qbarrier.getResult().getType(); 54 QuantizedType quantizedElementType = 55 QuantizedType::getQuantizedElementType(qbarrierResultType); 56 if (!quantizedElementType) { 57 return failure(); 58 } 59 if (!QuantizedType::castToStorageType(qbarrierResultType)) { 60 return failure(); 61 } 62 63 // Is the operand type compatible with the expressed type of the quantized 64 // type? This will not be true if the qbarrier is superfluous (converts 65 // from and to a quantized type). 66 if (!quantizedElementType.isCompatibleExpressedType( 67 qbarrier.arg().getType())) { 68 return failure(); 69 } 70 71 // Is the constant value a type expressed in a way that we support? 72 if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) { 73 return failure(); 74 } 75 76 Type newConstValueType; 77 auto newConstValue = 78 quantizeAttr(value, quantizedElementType, newConstValueType); 79 if (!newConstValue) { 80 return failure(); 81 } 82 83 // When creating the new const op, use a fused location that combines the 84 // original const and the qbarrier that led to the quantization. 85 auto fusedLoc = FusedLoc::get( 86 {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}, 87 rewriter.getContext()); 88 auto newConstOp = 89 rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); 90 rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(), 91 newConstOp); 92 return success(); 93 } 94 95 void ConvertConstPass::runOnFunction() { 96 OwningRewritePatternList patterns; 97 auto func = getFunction(); 98 auto *context = &getContext(); 99 patterns.insert<QuantizedConstRewrite>(context); 100 applyPatternsAndFoldGreedily(func, patterns); 101 } 102 103 std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() { 104 return std::make_unique<ConvertConstPass>(); 105 } 106