1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/FakeQuantSupport.h" 10 #include "mlir/Dialect/Quant/Passes.h" 11 #include "mlir/Dialect/Quant/QuantOps.h" 12 #include "mlir/Dialect/Quant/UniformSupport.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::quant; 20 21 namespace { 22 struct ConvertSimulatedQuantPass 23 : public PassWrapper<ConvertSimulatedQuantPass, FunctionPass> { 24 /// Include the generated pass utilities. 25 #define GEN_PASS_QuantConvertSimulatedQuant 26 #include "mlir/Dialect/Quant/Passes.h.inc" 27 28 void runOnFunction() override; 29 }; 30 31 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 32 template <typename ConcreteRewriteClass, typename FakeQuantOp> 33 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 34 public: 35 using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 36 37 FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 38 : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 39 40 LogicalResult matchAndRewrite(FakeQuantOp op, 41 PatternRewriter &rewriter) const override { 42 // TODO: If this pattern comes up more frequently, consider adding core 43 // support for failable rewrites. 44 if (failableRewrite(op, rewriter)) { 45 *hadFailure = true; 46 return failure(); 47 } 48 49 return success(); 50 } 51 52 private: 53 bool *hadFailure; 54 55 bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 56 auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 57 if (!converter) { 58 return (op.emitError("unsupported quantized type conversion"), true); 59 } 60 61 QuantizedType elementType = 62 static_cast<const ConcreteRewriteClass *>(this) 63 ->convertFakeQuantAttrsToType(op, converter.expressedType); 64 65 if (!elementType) { 66 // Note that the fakeQuantAttrsToType will have emitted the error. 67 return true; 68 } 69 70 Type quantizedType = converter.convert(elementType); 71 assert(quantizedType && 72 "Converter accepted a type that it did not convert"); 73 74 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 75 // this is a forced/hard-coded constraint. 76 auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 77 op.inputs()); 78 rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 79 qbarrier.getResult()); 80 81 return false; 82 } 83 }; 84 85 class ConstFakeQuantRewrite 86 : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 87 public: 88 using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 89 90 ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 91 : BaseRewrite(ctx, hadFailure) {} 92 93 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 94 Type expressedType) const { 95 return fakeQuantAttrsToType( 96 fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 97 fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 98 fqOp.narrow_range(), expressedType, fqOp.is_signed()); 99 } 100 }; 101 102 class ConstFakeQuantPerAxisRewrite 103 : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 104 ConstFakeQuantPerAxis> { 105 public: 106 using BaseRewrite = 107 FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 108 109 ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 110 : BaseRewrite(ctx, hadFailure) {} 111 112 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 113 Type expressedType) const { 114 SmallVector<double, 4> min, max; 115 min.reserve(fqOp.min().size()); 116 max.reserve(fqOp.max().size()); 117 for (auto m : fqOp.min()) 118 min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 119 for (auto m : fqOp.max()) 120 max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 121 122 return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 123 fqOp.axis().getSExtValue(), min, max, 124 fqOp.narrow_range(), expressedType, 125 fqOp.is_signed()); 126 } 127 }; 128 129 } // namespace 130 131 void ConvertSimulatedQuantPass::runOnFunction() { 132 bool hadFailure = false; 133 OwningRewritePatternList patterns; 134 auto func = getFunction(); 135 auto ctx = func.getContext(); 136 patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 137 ctx, &hadFailure); 138 applyPatternsGreedily(func, patterns); 139 if (hadFailure) 140 signalPassFailure(); 141 } 142 143 std::unique_ptr<OperationPass<FuncOp>> 144 mlir::quant::createConvertSimulatedQuantPass() { 145 return std::make_unique<ConvertSimulatedQuantPass>(); 146 } 147