1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/Ops.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/RegionUtils.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "pattern-matcher"
26 
27 static llvm::cl::opt<unsigned> maxPatternMatchIterations(
28     "mlir-max-pattern-match-iterations",
29     llvm::cl::desc("Max number of iterations scanning for pattern match"),
30     llvm::cl::init(10));
31 
32 namespace {
33 
34 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
35 /// applies the locally optimal patterns in a roughly "bottom up" way.
36 class GreedyPatternRewriteDriver : public PatternRewriter {
37 public:
38   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
39                                       const OwningRewritePatternList &patterns)
40       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
41     worklist.reserve(64);
42   }
43 
44   /// Perform the rewrites. Return true if the rewrite converges in
45   /// `maxIterations`.
46   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
47 
48   void addToWorklist(Operation *op) {
49     // Check to see if the worklist already contains this op.
50     if (worklistMap.count(op))
51       return;
52 
53     worklistMap[op] = worklist.size();
54     worklist.push_back(op);
55   }
56 
57   Operation *popFromWorklist() {
58     auto *op = worklist.back();
59     worklist.pop_back();
60 
61     // This operation is no longer in the worklist, keep worklistMap up to date.
62     if (op)
63       worklistMap.erase(op);
64     return op;
65   }
66 
67   /// If the specified operation is in the worklist, remove it.  If not, this is
68   /// a no-op.
69   void removeFromWorklist(Operation *op) {
70     auto it = worklistMap.find(op);
71     if (it != worklistMap.end()) {
72       assert(worklist[it->second] == op && "malformed worklist data structure");
73       worklist[it->second] = nullptr;
74       worklistMap.erase(it);
75     }
76   }
77 
78   // These are hooks implemented for PatternRewriter.
79 protected:
80   // Implement the hook for inserting operations, and make sure that newly
81   // inserted ops are added to the worklist for processing.
82   Operation *insert(Operation *op) override {
83     addToWorklist(op);
84     return OpBuilder::insert(op);
85   }
86 
87   // If an operation is about to be removed, make sure it is not in our
88   // worklist anymore because we'd get dangling references to it.
89   void notifyOperationRemoved(Operation *op) override {
90     addToWorklist(op->getOperands());
91     op->walk([this](Operation *operation) {
92       removeFromWorklist(operation);
93       folder.notifyRemoval(operation);
94     });
95   }
96 
97   // When the root of a pattern is about to be replaced, it can trigger
98   // simplifications to its users - make sure to add them to the worklist
99   // before the root is changed.
100   void notifyRootReplaced(Operation *op) override {
101     for (auto result : op->getResults())
102       for (auto *user : result.getUsers())
103         addToWorklist(user);
104   }
105 
106 private:
107   // Look over the provided operands for any defining operations that should
108   // be re-added to the worklist. This function should be called when an
109   // operation is modified or removed, as it may trigger further
110   // simplifications.
111   template <typename Operands> void addToWorklist(Operands &&operands) {
112     for (Value operand : operands) {
113       // If the use count of this operand is now < 2, we re-add the defining
114       // operation to the worklist.
115       // TODO(riverriddle) This is based on the fact that zero use operations
116       // may be deleted, and that single use values often have more
117       // canonicalization opportunities.
118       if (!operand.use_empty() && !operand.hasOneUse())
119         continue;
120       if (auto *defInst = operand.getDefiningOp())
121         addToWorklist(defInst);
122     }
123   }
124 
125   /// The low-level pattern matcher.
126   RewritePatternMatcher matcher;
127 
128   /// The worklist for this transformation keeps track of the operations that
129   /// need to be revisited, plus their index in the worklist.  This allows us to
130   /// efficiently remove operations from the worklist when they are erased, even
131   /// if they aren't the root of a pattern.
132   std::vector<Operation *> worklist;
133   DenseMap<Operation *, unsigned> worklistMap;
134 
135   /// Non-pattern based folder for operations.
136   OperationFolder folder;
137 };
138 } // end anonymous namespace
139 
140 /// Perform the rewrites.
141 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
142                                           int maxIterations) {
143   // Add the given operation to the worklist.
144   auto collectOps = [this](Operation *op) { addToWorklist(op); };
145 
146   bool changed = false;
147   int i = 0;
148   do {
149     // Add all nested operations to the worklist.
150     for (auto &region : regions)
151       region.walk(collectOps);
152 
153     // These are scratch vectors used in the folding loop below.
154     SmallVector<Value, 8> originalOperands, resultValues;
155 
156     changed = false;
157     while (!worklist.empty()) {
158       auto *op = popFromWorklist();
159 
160       // Nulls get added to the worklist when operations are removed, ignore
161       // them.
162       if (op == nullptr)
163         continue;
164 
165       // If the operation has no side effects, and no users, then it is
166       // trivially dead - remove it.
167       if (op->hasNoSideEffect() && op->use_empty()) {
168         // Be careful to update bookkeeping.
169         notifyOperationRemoved(op);
170         op->erase();
171         continue;
172       }
173 
174       // Collects all the operands and result uses of the given `op` into work
175       // list. Also remove `op` and nested ops from worklist.
176       originalOperands.assign(op->operand_begin(), op->operand_end());
177       auto preReplaceAction = [&](Operation *op) {
178         // Add the operands to the worklist for visitation.
179         addToWorklist(originalOperands);
180 
181         // Add all the users of the result to the worklist so we make sure
182         // to revisit them.
183         for (auto result : op->getResults())
184           for (auto *operand : result.getUsers())
185             addToWorklist(operand);
186 
187         notifyOperationRemoved(op);
188       };
189 
190       // Try to fold this op.
191       if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
192         changed |= true;
193         continue;
194       }
195 
196       // Make sure that any new operations are inserted at this point.
197       setInsertionPoint(op);
198 
199       // Try to match one of the patterns. The rewriter is automatically
200       // notified of any necessary changes, so there is nothing else to do here.
201       changed |= matcher.matchAndRewrite(op, *this);
202     }
203 
204     // After applying patterns, make sure that the CFG of each of the regions is
205     // kept up to date.
206     changed |= succeeded(simplifyRegions(regions));
207   } while (changed && ++i < maxIterations);
208   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
209   return !changed;
210 }
211 
212 /// Rewrite the regions of the specified operation, which must be isolated from
213 /// above, by repeatedly applying the highest benefit patterns in a greedy
214 /// work-list driven manner. Return true if no more patterns can be matched in
215 /// the result operation regions.
216 /// Note: This does not apply patterns to the top-level operation itself.
217 ///
218 bool mlir::applyPatternsGreedily(Operation *op,
219                                  const OwningRewritePatternList &patterns) {
220   return applyPatternsGreedily(op->getRegions(), patterns);
221 }
222 
223 /// Rewrite the given regions, which must be isolated from above.
224 bool mlir::applyPatternsGreedily(MutableArrayRef<Region> regions,
225                                  const OwningRewritePatternList &patterns) {
226   if (regions.empty())
227     return true;
228 
229   // The top-level operation must be known to be isolated from above to
230   // prevent performing canonicalizations on operations defined at or above
231   // the region containing 'op'.
232   auto regionIsIsolated = [](Region &region) {
233     return region.getParentOp()->isKnownIsolatedFromAbove();
234   };
235   (void)regionIsIsolated;
236   assert(llvm::all_of(regions, regionIsIsolated) &&
237          "patterns can only be applied to operations IsolatedFromAbove");
238 
239   // Start the pattern driver.
240   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
241   bool converged = driver.simplify(regions, maxPatternMatchIterations);
242   LLVM_DEBUG(if (!converged) {
243     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
244                  << maxPatternMatchIterations << " times";
245   });
246   return converged;
247 }
248