1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Quant/FakeQuantSupport.h" 11 #include "mlir/Dialect/Quant/Passes.h" 12 #include "mlir/Dialect/Quant/QuantOps.h" 13 #include "mlir/Dialect/Quant/UniformSupport.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 using namespace mlir; 18 using namespace mlir::quant; 19 20 namespace { 21 struct ConvertSimulatedQuantPass 22 : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> { 23 void runOnOperation() override; 24 }; 25 26 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 27 template <typename ConcreteRewriteClass, typename FakeQuantOp> 28 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { 29 public: 30 using OpRewritePattern<FakeQuantOp>::OpRewritePattern; 31 32 FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 33 : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} 34 35 LogicalResult matchAndRewrite(FakeQuantOp op, 36 PatternRewriter &rewriter) const override { 37 // TODO: If this pattern comes up more frequently, consider adding core 38 // support for failable rewrites. 39 if (failableRewrite(op, rewriter)) { 40 *hadFailure = true; 41 return failure(); 42 } 43 44 return success(); 45 } 46 47 private: 48 bool *hadFailure; 49 50 bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { 51 auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); 52 if (!converter) { 53 return (op.emitError("unsupported quantized type conversion"), true); 54 } 55 56 QuantizedType elementType = 57 static_cast<const ConcreteRewriteClass *>(this) 58 ->convertFakeQuantAttrsToType(op, converter.expressedType); 59 60 if (!elementType) { 61 // Note that the fakeQuantAttrsToType will have emitted the error. 62 return true; 63 } 64 65 Type quantizedType = converter.convert(elementType); 66 assert(quantizedType && 67 "Converter accepted a type that it did not convert"); 68 69 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 70 // this is a forced/hard-coded constraint. 71 auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, 72 op.getInputs()); 73 rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 74 qbarrier.getResult()); 75 76 return false; 77 } 78 }; 79 80 class ConstFakeQuantRewrite 81 : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { 82 public: 83 using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; 84 85 ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) 86 : BaseRewrite(ctx, hadFailure) {} 87 88 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, 89 Type expressedType) const { 90 return fakeQuantAttrsToType( 91 fqOp.getLoc(), fqOp.getNumBits(), fqOp.getMin().convertToFloat(), 92 fqOp.getMax().convertToFloat(), fqOp.getNarrowRange(), expressedType, 93 fqOp.getIsSigned()); 94 } 95 }; 96 97 class ConstFakeQuantPerAxisRewrite 98 : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, 99 ConstFakeQuantPerAxis> { 100 public: 101 using BaseRewrite = 102 FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; 103 104 ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) 105 : BaseRewrite(ctx, hadFailure) {} 106 107 QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, 108 Type expressedType) const { 109 SmallVector<double, 4> min, max; 110 min.reserve(fqOp.getMin().size()); 111 max.reserve(fqOp.getMax().size()); 112 for (auto m : fqOp.getMin()) 113 min.push_back(m.cast<FloatAttr>().getValueAsDouble()); 114 for (auto m : fqOp.getMax()) 115 max.push_back(m.cast<FloatAttr>().getValueAsDouble()); 116 117 return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(), 118 fqOp.getAxis(), min, max, fqOp.getNarrowRange(), 119 expressedType, fqOp.getIsSigned()); 120 } 121 }; 122 123 } // namespace 124 125 void ConvertSimulatedQuantPass::runOnOperation() { 126 bool hadFailure = false; 127 auto func = getOperation(); 128 RewritePatternSet patterns(func.getContext()); 129 auto *ctx = func.getContext(); 130 patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( 131 ctx, &hadFailure); 132 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 133 if (hadFailure) 134 signalPassFailure(); 135 } 136 137 std::unique_ptr<OperationPass<func::FuncOp>> 138 mlir::quant::createConvertSimulatedQuantPass() { 139 return std::make_unique<ConvertSimulatedQuantPass>(); 140 } 141