1 //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/OwningOpRef.h" 10 #include "mlir/IR/PatternMatch.h" 11 #include "mlir/Rewrite/PatternApplicator.h" 12 #include "gtest/gtest.h" 13 14 using namespace mlir; 15 16 namespace { 17 TEST(PatternBenefitTest, BenefitOrder) { 18 // There was a bug which caused low-benefit op-specific patterns to never be 19 // called in presence of high-benefit op-agnostic pattern 20 21 MLIRContext context; 22 23 OpBuilder builder(&context); 24 OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc()); 25 26 struct Pattern1 : public OpRewritePattern<ModuleOp> { 27 Pattern1(mlir::MLIRContext *context, bool *called) 28 : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {} 29 30 mlir::LogicalResult 31 matchAndRewrite(ModuleOp /*op*/, 32 mlir::PatternRewriter & /*rewriter*/) const override { 33 *called = true; 34 return failure(); 35 } 36 37 private: 38 bool *called; 39 }; 40 41 struct Pattern2 : public RewritePattern { 42 Pattern2(MLIRContext *context, bool *called) 43 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context), 44 called(called) {} 45 46 mlir::LogicalResult 47 matchAndRewrite(Operation * /*op*/, 48 mlir::PatternRewriter & /*rewriter*/) const override { 49 *called = true; 50 return failure(); 51 } 52 53 private: 54 bool *called; 55 }; 56 57 RewritePatternSet patterns(&context); 58 59 bool called1 = false; 60 bool called2 = false; 61 62 patterns.add<Pattern1>(&context, &called1); 63 patterns.add<Pattern2>(&context, &called2); 64 65 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 66 PatternApplicator pa(frozenPatterns); 67 pa.applyDefaultCostModel(); 68 69 class MyPatternRewriter : public PatternRewriter { 70 public: 71 MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} 72 }; 73 74 MyPatternRewriter rewriter(&context); 75 (void)pa.matchAndRewrite(*module, rewriter); 76 77 EXPECT_TRUE(called1); 78 EXPECT_TRUE(called2); 79 } 80 } // namespace 81