1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Quant/Passes.h" 11 #include "mlir/Dialect/Quant/QuantOps.h" 12 #include "mlir/Dialect/Quant/QuantizeUtils.h" 13 #include "mlir/Dialect/Quant/UniformSupport.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::quant; 21 22 namespace { 23 struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> { 24 void runOnFunction() override; 25 }; 26 27 struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 28 using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 29 30 LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, 31 PatternRewriter &rewriter) const override; 32 }; 33 34 } // end anonymous namespace 35 36 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 37 /// quantized and the operand type is quantizable. 38 39 LogicalResult 40 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 41 PatternRewriter &rewriter) const { 42 Attribute value; 43 44 // Is the operand a constant? 45 if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 46 return failure(); 47 } 48 49 // Does the qbarrier convert to a quantized type. This will not be true 50 // if a quantized type has not yet been chosen or if the cast to an equivalent 51 // storage type is not supported. 52 Type qbarrierResultType = qbarrier.getResult().getType(); 53 QuantizedType quantizedElementType = 54 QuantizedType::getQuantizedElementType(qbarrierResultType); 55 if (!quantizedElementType) { 56 return failure(); 57 } 58 if (!QuantizedType::castToStorageType(qbarrierResultType)) { 59 return failure(); 60 } 61 62 // Is the operand type compatible with the expressed type of the quantized 63 // type? This will not be true if the qbarrier is superfluous (converts 64 // from and to a quantized type). 65 if (!quantizedElementType.isCompatibleExpressedType( 66 qbarrier.arg().getType())) { 67 return failure(); 68 } 69 70 // Is the constant value a type expressed in a way that we support? 71 if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) { 72 return failure(); 73 } 74 75 Type newConstValueType; 76 auto newConstValue = 77 quantizeAttr(value, quantizedElementType, newConstValueType); 78 if (!newConstValue) { 79 return failure(); 80 } 81 82 // When creating the new const op, use a fused location that combines the 83 // original const and the qbarrier that led to the quantization. 84 auto fusedLoc = FusedLoc::get( 85 {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}, 86 rewriter.getContext()); 87 auto newConstOp = 88 rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); 89 rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(), 90 newConstOp); 91 return success(); 92 } 93 94 void ConvertConstPass::runOnFunction() { 95 OwningRewritePatternList patterns; 96 auto func = getFunction(); 97 auto *context = &getContext(); 98 patterns.insert<QuantizedConstRewrite>(context); 99 applyPatternsAndFoldGreedily(func, std::move(patterns)); 100 } 101 102 std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() { 103 return std::make_unique<ConvertConstPass>(); 104 } 105