1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/Passes.h" 10 #include "mlir/Dialect/Quant/QuantOps.h" 11 #include "mlir/Dialect/Quant/QuantizeUtils.h" 12 #include "mlir/Dialect/Quant/UniformSupport.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/StandardTypes.h" 18 #include "mlir/Pass/Pass.h" 19 20 using namespace mlir; 21 using namespace mlir::quant; 22 23 namespace { 24 25 class ConvertConstPass : public FunctionPass<ConvertConstPass> { 26 public: 27 void runOnFunction() override; 28 }; 29 30 struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 31 using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 32 33 LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, 34 PatternRewriter &rewriter) const override; 35 }; 36 37 } // end anonymous namespace 38 39 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 40 /// quantized and the operand type is quantizable. 41 42 LogicalResult 43 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 44 PatternRewriter &rewriter) const { 45 Attribute value; 46 47 // Is the operand a constant? 48 if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 49 return failure(); 50 } 51 52 // Does the qbarrier convert to a quantized type. This will not be true 53 // if a quantized type has not yet been chosen or if the cast to an equivalent 54 // storage type is not supported. 55 Type qbarrierResultType = qbarrier.getResult().getType(); 56 QuantizedType quantizedElementType = 57 QuantizedType::getQuantizedElementType(qbarrierResultType); 58 if (!quantizedElementType) { 59 return failure(); 60 } 61 if (!QuantizedType::castToStorageType(qbarrierResultType)) { 62 return failure(); 63 } 64 65 // Is the operand type compatible with the expressed type of the quantized 66 // type? This will not be true if the qbarrier is superfluous (converts 67 // from and to a quantized type). 68 if (!quantizedElementType.isCompatibleExpressedType( 69 qbarrier.arg().getType())) { 70 return failure(); 71 } 72 73 // Is the constant value a type expressed in a way that we support? 74 if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() && 75 !value.isa<SparseElementsAttr>()) { 76 return failure(); 77 } 78 79 Type newConstValueType; 80 auto newConstValue = 81 quantizeAttr(value, quantizedElementType, newConstValueType); 82 if (!newConstValue) { 83 return failure(); 84 } 85 86 // When creating the new const op, use a fused location that combines the 87 // original const and the qbarrier that led to the quantization. 88 auto fusedLoc = FusedLoc::get( 89 {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}, 90 rewriter.getContext()); 91 auto newConstOp = 92 rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); 93 rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(), 94 newConstOp); 95 return success(); 96 } 97 98 void ConvertConstPass::runOnFunction() { 99 OwningRewritePatternList patterns; 100 auto func = getFunction(); 101 auto *context = &getContext(); 102 patterns.insert<QuantizedConstRewrite>(context); 103 applyPatternsGreedily(func, patterns); 104 } 105 106 std::unique_ptr<OpPassBase<FuncOp>> mlir::quant::createConvertConstPass() { 107 return std::make_unique<ConvertConstPass>(); 108 } 109 110 static PassRegistration<ConvertConstPass> 111 pass("quant-convert-const", 112 "Converts constants followed by qbarrier to actual quantized values"); 113