1 //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Shape/IR/Shape.h" 11 #include "mlir/Dialect/Shape/Transforms/Passes.h" 12 #include "mlir/Transforms/DialectConversion.h" 13 14 using namespace mlir; 15 16 namespace { 17 /// Removal patterns. 18 class RemoveCstrBroadcastableOp 19 : public OpRewritePattern<shape::CstrBroadcastableOp> { 20 public: 21 using OpRewritePattern::OpRewritePattern; 22 23 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 24 PatternRewriter &rewriter) const override { 25 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 26 return success(); 27 } 28 }; 29 30 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { 31 public: 32 using OpRewritePattern::OpRewritePattern; 33 34 LogicalResult matchAndRewrite(shape::CstrEqOp op, 35 PatternRewriter &rewriter) const override { 36 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 37 return success(); 38 } 39 }; 40 41 /// Removal pass. 42 class RemoveShapeConstraintsPass 43 : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { 44 45 void runOnFunction() override { 46 MLIRContext &ctx = getContext(); 47 48 OwningRewritePatternList patterns; 49 populateRemoveShapeConstraintsPatterns(patterns, &ctx); 50 51 applyPatternsAndFoldGreedily(getFunction(), patterns); 52 } 53 }; 54 55 } // namespace 56 57 void mlir::populateRemoveShapeConstraintsPatterns( 58 OwningRewritePatternList &patterns, MLIRContext *ctx) { 59 patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx); 60 } 61 62 std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() { 63 return std::make_unique<RemoveShapeConstraintsPass>(); 64 } 65