//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr Ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Shape/IR/Shape.h"
11 #include "mlir/Dialect/Shape/Transforms/Passes.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 /// Removal patterns.
18 class RemoveCstrBroadcastableOp
19     : public OpRewritePattern<shape::CstrBroadcastableOp> {
20 public:
21   using OpRewritePattern::OpRewritePattern;
22 
23   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
24                                 PatternRewriter &rewriter) const override {
25     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
26     return success();
27   }
28 };
29 
30 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
31 public:
32   using OpRewritePattern::OpRewritePattern;
33 
34   LogicalResult matchAndRewrite(shape::CstrEqOp op,
35                                 PatternRewriter &rewriter) const override {
36     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
37     return success();
38   }
39 };
40 
41 /// Removal pass.
42 class RemoveShapeConstraintsPass
43     : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
44 
45   void runOnFunction() override {
46     MLIRContext &ctx = getContext();
47 
48     OwningRewritePatternList patterns;
49     populateRemoveShapeConstraintsPatterns(patterns, &ctx);
50 
51     applyPatternsAndFoldGreedily(getFunction(), patterns);
52   }
53 };
54 
55 } // namespace
56 
57 void mlir::populateRemoveShapeConstraintsPatterns(
58     OwningRewritePatternList &patterns, MLIRContext *ctx) {
59   patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
60 }
61 
62 std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
63   return std::make_unique<RemoveShapeConstraintsPass>();
64 }
65