133245988STres Popp //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// 233245988STres Popp // 333245988STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 433245988STres Popp // See https://llvm.org/LICENSE.txt for license information. 533245988STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 633245988STres Popp // 733245988STres Popp //===----------------------------------------------------------------------===// 833245988STres Popp 933245988STres Popp #include "PassDetail.h" 1033245988STres Popp #include "mlir/Dialect/Shape/IR/Shape.h" 1133245988STres Popp #include "mlir/Dialect/Shape/Transforms/Passes.h" 1233245988STres Popp #include "mlir/Transforms/DialectConversion.h" 13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1433245988STres Popp 1533245988STres Popp using namespace mlir; 1633245988STres Popp 1733245988STres Popp namespace { 1833245988STres Popp /// Removal patterns. 1933245988STres Popp class RemoveCstrBroadcastableOp 2033245988STres Popp : public OpRewritePattern<shape::CstrBroadcastableOp> { 2133245988STres Popp public: 2233245988STres Popp using OpRewritePattern::OpRewritePattern; 2333245988STres Popp matchAndRewrite(shape::CstrBroadcastableOp op,PatternRewriter & rewriter) const2433245988STres Popp LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 2533245988STres Popp PatternRewriter &rewriter) const override { 2633245988STres Popp rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 2733245988STres Popp return success(); 2833245988STres Popp } 2933245988STres Popp }; 3033245988STres Popp 3133245988STres Popp class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { 3233245988STres Popp public: 3333245988STres Popp using OpRewritePattern::OpRewritePattern; 3433245988STres Popp matchAndRewrite(shape::CstrEqOp op,PatternRewriter & rewriter) const3533245988STres Popp LogicalResult matchAndRewrite(shape::CstrEqOp op, 3633245988STres Popp PatternRewriter &rewriter) const override { 3733245988STres Popp rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 3833245988STres Popp return success(); 3933245988STres Popp } 4033245988STres Popp }; 4133245988STres Popp 4233245988STres Popp /// Removal pass. 4333245988STres Popp class RemoveShapeConstraintsPass 4433245988STres Popp : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { 4533245988STres Popp runOnOperation()4641574554SRiver Riddle void runOnOperation() override { 4733245988STres Popp MLIRContext &ctx = getContext(); 4833245988STres Popp 49dc4e913bSChris Lattner RewritePatternSet patterns(&ctx); 503a506b31SChris Lattner populateRemoveShapeConstraintsPatterns(patterns); 5133245988STres Popp 5241574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 5333245988STres Popp } 5433245988STres Popp }; 5533245988STres Popp 5633245988STres Popp } // namespace 5733245988STres Popp populateRemoveShapeConstraintsPatterns(RewritePatternSet & patterns)58dc4e913bSChris Lattnervoid mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) { 59dc4e913bSChris Lattner patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>( 603a506b31SChris Lattner patterns.getContext()); 6133245988STres Popp } 6233245988STres Popp 63*58ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass()6441574554SRiver Riddlemlir::createRemoveShapeConstraintsPass() { 6533245988STres Popp return std::make_unique<RemoveShapeConstraintsPass>(); 6633245988STres Popp } 67