1 //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Shape/IR/Shape.h" 11 #include "mlir/Dialect/Shape/Transforms/Passes.h" 12 #include "mlir/Transforms/DialectConversion.h" 13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14 15 using namespace mlir; 16 17 namespace { 18 /// Removal patterns. 19 class RemoveCstrBroadcastableOp 20 : public OpRewritePattern<shape::CstrBroadcastableOp> { 21 public: 22 using OpRewritePattern::OpRewritePattern; 23 matchAndRewrite(shape::CstrBroadcastableOp op,PatternRewriter & rewriter) const24 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 25 PatternRewriter &rewriter) const override { 26 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 27 return success(); 28 } 29 }; 30 31 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { 32 public: 33 using OpRewritePattern::OpRewritePattern; 34 matchAndRewrite(shape::CstrEqOp op,PatternRewriter & rewriter) const35 LogicalResult matchAndRewrite(shape::CstrEqOp op, 36 PatternRewriter &rewriter) const override { 37 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 38 return success(); 39 } 40 }; 41 42 /// Removal pass. 43 class RemoveShapeConstraintsPass 44 : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { 45 runOnOperation()46 void runOnOperation() override { 47 MLIRContext &ctx = getContext(); 48 49 RewritePatternSet patterns(&ctx); 50 populateRemoveShapeConstraintsPatterns(patterns); 51 52 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 53 } 54 }; 55 56 } // namespace 57 populateRemoveShapeConstraintsPatterns(RewritePatternSet & patterns)58void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) { 59 patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>( 60 patterns.getContext()); 61 } 62 63 std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass()64mlir::createRemoveShapeConstraintsPass() { 65 return std::make_unique<RemoveShapeConstraintsPass>(); 66 } 67