1 //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Rewrite/PatternApplicator.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 TEST(PatternBenefitTest, BenefitOrder) {
17   // There was a bug which caused low-benefit op-specific patterns to never be
18   // called in presence of high-benefit op-agnostic pattern
19 
20   MLIRContext context;
21 
22   OpBuilder builder(&context);
23   auto module = ModuleOp::create(builder.getUnknownLoc());
24 
25   struct Pattern1 : public OpRewritePattern<ModuleOp> {
26     Pattern1(mlir::MLIRContext *context, bool *called)
27         : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
28 
29     mlir::LogicalResult
30     matchAndRewrite(ModuleOp /*op*/,
31                     mlir::PatternRewriter & /*rewriter*/) const override {
32       *called = true;
33       return failure();
34     }
35 
36   private:
37     bool *called;
38   };
39 
40   struct Pattern2 : public RewritePattern {
41     Pattern2(MLIRContext *context, bool *called)
42         : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
43           called(called) {}
44 
45     mlir::LogicalResult
46     matchAndRewrite(Operation * /*op*/,
47                     mlir::PatternRewriter & /*rewriter*/) const override {
48       *called = true;
49       return failure();
50     }
51 
52   private:
53     bool *called;
54   };
55 
56   RewritePatternSet patterns(&context);
57 
58   bool called1 = false;
59   bool called2 = false;
60 
61   patterns.add<Pattern1>(&context, &called1);
62   patterns.add<Pattern2>(&context, &called2);
63 
64   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
65   PatternApplicator pa(frozenPatterns);
66   pa.applyDefaultCostModel();
67 
68   class MyPatternRewriter : public PatternRewriter {
69   public:
70     MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
71   };
72 
73   MyPatternRewriter rewriter(&context);
74   (void)pa.matchAndRewrite(module, rewriter);
75 
76   EXPECT_TRUE(called1);
77   EXPECT_TRUE(called2);
78 }
79 } // namespace
80