1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/FunctionInterfaces.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 /// This is a test pass for verifying matchers.
19 struct TestMatchers
20     : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> {
21   void runOnOperation() override;
22   StringRef getArgument() const final { return "test-matchers"; }
23   StringRef getDescription() const final {
24     return "Test C++ pattern matchers.";
25   }
26 };
27 } // namespace
28 
29 // This could be done better but is not worth the variadic template trouble.
30 template <typename Matcher>
31 static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) {
32   unsigned count = 0;
33   f.walk([&count, &matcher](Operation *op) {
34     if (matcher.match(op))
35       ++count;
36   });
37   return count;
38 }
39 
40 using mlir::matchers::m_Any;
41 using mlir::matchers::m_Val;
42 static void test1(FunctionOpInterface f) {
43   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
44 
45   auto a = m_Val(f.getArgument(0));
46   auto b = m_Val(f.getArgument(1));
47   auto c = m_Val(f.getArgument(2));
48 
49   auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
50   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
51                << " times\n";
52 
53   auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
54   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
55                << " times\n";
56 
57   auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
58   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
59                << " times\n";
60 
61   auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
62   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
63                << " times\n";
64 
65   auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
66   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
67                << " times\n";
68 
69   auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
70   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
71                << " times\n";
72 
73   auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
74   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
75                << " times\n";
76 
77   auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
78   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
79                << " times\n";
80 
81   auto mulOfMulmul =
82       m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
83   auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul);
84   llvm::outs()
85       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
86       << countMatches(f, p8) << " times\n";
87 
88   // clang-format off
89   auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
90   auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
91   auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
92                      mulOfMuladd, m_Op<arith::MulFOp>()),
93                    m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd));
94   // clang-format on
95   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
96                   "add(*)), mul(*, add(*)))) matched "
97                << countMatches(f, p9) << " times\n";
98 
99   auto p10 = m_Op<arith::AddFOp>(a, b);
100   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
101                << " times\n";
102 
103   auto p11 = m_Op<arith::AddFOp>(a, c);
104   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
105                << " times\n";
106 
107   auto p12 = m_Op<arith::AddFOp>(b, a);
108   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
109                << " times\n";
110 
111   auto p13 = m_Op<arith::AddFOp>(c, a);
112   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
113                << " times\n";
114 
115   auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
116   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
117                << " times\n";
118 
119   auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
120   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
121                << " times\n";
122 
123   auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any());
124   auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c));
125   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
126                << countMatches(f, p16) << " times\n";
127 
128   auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b));
129   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
130                << countMatches(f, p17) << " times\n";
131 }
132 
133 void test2(FunctionOpInterface f) {
134   auto a = m_Val(f.getArgument(0));
135   FloatAttr floatAttr;
136   auto p =
137       m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
138   auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
139   // Last operation that is not the terminator.
140   Operation *lastOp = f.getBody().front().back().getPrevNode();
141   if (p.match(lastOp))
142     llvm::outs()
143         << "Pattern add(add(a, constant), a) matched and bound constant to: "
144         << floatAttr.getValueAsDouble() << "\n";
145   if (p1.match(lastOp))
146     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
147 }
148 
149 void TestMatchers::runOnOperation() {
150   auto f = getOperation();
151   llvm::outs() << f.getName() << "\n";
152   if (f.getName() == "test1")
153     test1(f);
154   if (f.getName() == "test2")
155     test2(f);
156 }
157 
158 namespace mlir {
159 void registerTestMatchers() { PassRegistration<TestMatchers>(); }
160 } // namespace mlir
161