1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Quant/FakeQuantSupport.h" 11 #include "mlir/Dialect/Quant/Passes.h" 12 #include "mlir/Dialect/Quant/QuantOps.h" 13 #include "mlir/Dialect/Quant/UniformSupport.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 18 using namespace mlir; 19 using namespace mlir::quant; 20 21 namespace { 22 struct ConvertSimulatedQuantPass 23 : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> { 24 void runOnFunction() override; 25 }; 26 27 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 28 template <typename ConcreteRewriteClass, typename FakeQuantOp> 29 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 30 public: 31 using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 32 33 FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 34 : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 35 36 LogicalResult matchAndRewrite(FakeQuantOp op, 37 PatternRewriter &rewriter) const override { 38 // TODO: If this pattern comes up more frequently, consider adding core 39 // support for failable rewrites. 40 if (failableRewrite(op, rewriter)) { 41 *hadFailure = true; 42 return failure(); 43 } 44 45 return success(); 46 } 47 48 private: 49 bool *hadFailure; 50 51 bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 52 auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 53 if (!converter) { 54 return (op.emitError("unsupported quantized type conversion"), true); 55 } 56 57 QuantizedType elementType = 58 static_cast<const ConcreteRewriteClass *>(this) 59 ->convertFakeQuantAttrsToType(op, converter.expressedType); 60 61 if (!elementType) { 62 // Note that the fakeQuantAttrsToType will have emitted the error. 63 return true; 64 } 65 66 Type quantizedType = converter.convert(elementType); 67 assert(quantizedType && 68 "Converter accepted a type that it did not convert"); 69 70 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 71 // this is a forced/hard-coded constraint. 72 auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 73 op.inputs()); 74 rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 75 qbarrier.getResult()); 76 77 return false; 78 } 79 }; 80 81 class ConstFakeQuantRewrite 82 : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 83 public: 84 using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 85 86 ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 87 : BaseRewrite(ctx, hadFailure) {} 88 89 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 90 Type expressedType) const { 91 return fakeQuantAttrsToType( 92 fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 93 fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 94 fqOp.narrow_range(), expressedType, fqOp.is_signed()); 95 } 96 }; 97 98 class ConstFakeQuantPerAxisRewrite 99 : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 100 ConstFakeQuantPerAxis> { 101 public: 102 using BaseRewrite = 103 FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 104 105 ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 106 : BaseRewrite(ctx, hadFailure) {} 107 108 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 109 Type expressedType) const { 110 SmallVector<double, 4> min, max; 111 min.reserve(fqOp.min().size()); 112 max.reserve(fqOp.max().size()); 113 for (auto m : fqOp.min()) 114 min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 115 for (auto m : fqOp.max()) 116 max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 117 118 return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 119 fqOp.axis().getSExtValue(), min, max, 120 fqOp.narrow_range(), expressedType, 121 fqOp.is_signed()); 122 } 123 }; 124 125 } // namespace 126 127 void ConvertSimulatedQuantPass::runOnFunction() { 128 bool hadFailure = false; 129 OwningRewritePatternList patterns; 130 auto func = getFunction(); 131 auto ctx = func.getContext(); 132 patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 133 ctx, &hadFailure); 134 applyPatternsAndFoldGreedily(func, patterns); 135 if (hadFailure) 136 signalPassFailure(); 137 } 138 139 std::unique_ptr<OperationPass<FuncOp>> 140 mlir::quant::createConvertSimulatedQuantPass() { 141 return std::make_unique<ConvertSimulatedQuantPass>(); 142 } 143