1 //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/Rewrite/PatternApplicator.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 15 namespace { 16 TEST(PatternBenefitTest, BenefitOrder) { 17 // There was a bug which caused low-benefit op-specific patterns to never be 18 // called in presence of high-benefit op-agnostic pattern 19 20 MLIRContext context; 21 22 OpBuilder builder(&context); 23 auto module = ModuleOp::create(builder.getUnknownLoc()); 24 25 struct Pattern1 : public OpRewritePattern<ModuleOp> { 26 Pattern1(mlir::MLIRContext *context, bool *called) 27 : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {} 28 29 mlir::LogicalResult 30 matchAndRewrite(ModuleOp /*op*/, 31 mlir::PatternRewriter & /*rewriter*/) const override { 32 *called = true; 33 return failure(); 34 } 35 36 private: 37 bool *called; 38 }; 39 40 struct Pattern2 : public RewritePattern { 41 Pattern2(MLIRContext *context, bool *called) 42 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context), 43 called(called) {} 44 45 mlir::LogicalResult 46 matchAndRewrite(Operation * /*op*/, 47 mlir::PatternRewriter & /*rewriter*/) const override { 48 *called = true; 49 return failure(); 50 } 51 52 private: 53 bool *called; 54 }; 55 56 RewritePatternSet patterns(&context); 57 58 bool called1 = false; 59 bool called2 = false; 60 61 patterns.add<Pattern1>(&context, &called1); 62 patterns.add<Pattern2>(&context, &called2); 63 64 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 65 PatternApplicator pa(frozenPatterns); 66 pa.applyDefaultCostModel(); 67 68 class MyPatternRewriter : public PatternRewriter { 69 public: 70 MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} 71 }; 72 73 MyPatternRewriter rewriter(&context); 74 (void)pa.matchAndRewrite(module, rewriter); 75 76 EXPECT_TRUE(called1); 77 EXPECT_TRUE(called2); 78 } 79 } // namespace 80