1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/FunctionInterfaces.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Pass/Pass.h" 14 15 using namespace mlir; 16 17 namespace { 18 /// This is a test pass for verifying matchers. 19 struct TestMatchers 20 : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> { 21 void runOnOperation() override; 22 StringRef getArgument() const final { return "test-matchers"; } 23 StringRef getDescription() const final { 24 return "Test C++ pattern matchers."; 25 } 26 }; 27 } // namespace 28 29 // This could be done better but is not worth the variadic template trouble. 30 template <typename Matcher> 31 static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) { 32 unsigned count = 0; 33 f.walk([&count, &matcher](Operation *op) { 34 if (matcher.match(op)) 35 ++count; 36 }); 37 return count; 38 } 39 40 using mlir::matchers::m_Any; 41 using mlir::matchers::m_Val; 42 static void test1(FunctionOpInterface f) { 43 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); 44 45 auto a = m_Val(f.getArgument(0)); 46 auto b = m_Val(f.getArgument(1)); 47 auto c = m_Val(f.getArgument(2)); 48 49 auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher 50 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) 51 << " times\n"; 52 53 auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher 54 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) 55 << " times\n"; 56 57 auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any()); 58 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) 59 << " times\n"; 60 61 auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>()); 62 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) 63 << " times\n"; 64 65 auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any()); 66 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) 67 << " times\n"; 68 69 auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>()); 70 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) 71 << " times\n"; 72 73 auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any()); 74 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) 75 << " times\n"; 76 77 auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>()); 78 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7) 79 << " times\n"; 80 81 auto mulOfMulmul = 82 m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>()); 83 auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul); 84 llvm::outs() 85 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched " 86 << countMatches(f, p8) << " times\n"; 87 88 // clang-format off 89 auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>()); 90 auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>()); 91 auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>( 92 mulOfMuladd, m_Op<arith::MulFOp>()), 93 m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd)); 94 // clang-format on 95 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, " 96 "add(*)), mul(*, add(*)))) matched " 97 << countMatches(f, p9) << " times\n"; 98 99 auto p10 = m_Op<arith::AddFOp>(a, b); 100 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10) 101 << " times\n"; 102 103 auto p11 = m_Op<arith::AddFOp>(a, c); 104 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11) 105 << " times\n"; 106 107 auto p12 = m_Op<arith::AddFOp>(b, a); 108 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12) 109 << " times\n"; 110 111 auto p13 = m_Op<arith::AddFOp>(c, a); 112 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13) 113 << " times\n"; 114 115 auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b)); 116 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14) 117 << " times\n"; 118 119 auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c)); 120 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) 121 << " times\n"; 122 123 auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any()); 124 auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c)); 125 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " 126 << countMatches(f, p16) << " times\n"; 127 128 auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b)); 129 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched " 130 << countMatches(f, p17) << " times\n"; 131 } 132 133 void test2(FunctionOpInterface f) { 134 auto a = m_Val(f.getArgument(0)); 135 FloatAttr floatAttr; 136 auto p = 137 m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr))); 138 auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant())); 139 // Last operation that is not the terminator. 140 Operation *lastOp = f.getBody().front().back().getPrevNode(); 141 if (p.match(lastOp)) 142 llvm::outs() 143 << "Pattern add(add(a, constant), a) matched and bound constant to: " 144 << floatAttr.getValueAsDouble() << "\n"; 145 if (p1.match(lastOp)) 146 llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; 147 } 148 149 void TestMatchers::runOnOperation() { 150 auto f = getOperation(); 151 llvm::outs() << f.getName() << "\n"; 152 if (f.getName() == "test1") 153 test1(f); 154 if (f.getName() == "test2") 155 test2(f); 156 } 157 158 namespace mlir { 159 void registerTestMatchers() { PassRegistration<TestMatchers>(); } 160 } // namespace mlir 161