133245988STres Popp //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
233245988STres Popp //
333245988STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
433245988STres Popp // See https://llvm.org/LICENSE.txt for license information.
533245988STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
633245988STres Popp //
733245988STres Popp //===----------------------------------------------------------------------===//
833245988STres Popp 
933245988STres Popp #include "PassDetail.h"
1033245988STres Popp #include "mlir/Dialect/Shape/IR/Shape.h"
1133245988STres Popp #include "mlir/Dialect/Shape/Transforms/Passes.h"
1233245988STres Popp #include "mlir/Transforms/DialectConversion.h"
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1433245988STres Popp 
1533245988STres Popp using namespace mlir;
1633245988STres Popp 
1733245988STres Popp namespace {
1833245988STres Popp /// Removal patterns.
1933245988STres Popp class RemoveCstrBroadcastableOp
2033245988STres Popp     : public OpRewritePattern<shape::CstrBroadcastableOp> {
2133245988STres Popp public:
2233245988STres Popp   using OpRewritePattern::OpRewritePattern;
2333245988STres Popp 
matchAndRewrite(shape::CstrBroadcastableOp op,PatternRewriter & rewriter) const2433245988STres Popp   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
2533245988STres Popp                                 PatternRewriter &rewriter) const override {
2633245988STres Popp     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
2733245988STres Popp     return success();
2833245988STres Popp   }
2933245988STres Popp };
3033245988STres Popp 
3133245988STres Popp class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
3233245988STres Popp public:
3333245988STres Popp   using OpRewritePattern::OpRewritePattern;
3433245988STres Popp 
matchAndRewrite(shape::CstrEqOp op,PatternRewriter & rewriter) const3533245988STres Popp   LogicalResult matchAndRewrite(shape::CstrEqOp op,
3633245988STres Popp                                 PatternRewriter &rewriter) const override {
3733245988STres Popp     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
3833245988STres Popp     return success();
3933245988STres Popp   }
4033245988STres Popp };
4133245988STres Popp 
4233245988STres Popp /// Removal pass.
4333245988STres Popp class RemoveShapeConstraintsPass
4433245988STres Popp     : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
4533245988STres Popp 
runOnOperation()4641574554SRiver Riddle   void runOnOperation() override {
4733245988STres Popp     MLIRContext &ctx = getContext();
4833245988STres Popp 
49dc4e913bSChris Lattner     RewritePatternSet patterns(&ctx);
503a506b31SChris Lattner     populateRemoveShapeConstraintsPatterns(patterns);
5133245988STres Popp 
5241574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5333245988STres Popp   }
5433245988STres Popp };
5533245988STres Popp 
5633245988STres Popp } // namespace
5733245988STres Popp 
populateRemoveShapeConstraintsPatterns(RewritePatternSet & patterns)58dc4e913bSChris Lattner void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
59dc4e913bSChris Lattner   patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
603a506b31SChris Lattner       patterns.getContext());
6133245988STres Popp }
6233245988STres Popp 
63*58ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createRemoveShapeConstraintsPass()6441574554SRiver Riddle mlir::createRemoveShapeConstraintsPass() {
6533245988STres Popp   return std::make_unique<RemoveShapeConstraintsPass>();
6633245988STres Popp }
67