1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/SideEffects.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/RegionUtils.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "pattern-matcher"
25 
26 static llvm::cl::opt<unsigned> maxPatternMatchIterations(
27     "mlir-max-pattern-match-iterations",
28     llvm::cl::desc("Max number of iterations scanning for pattern match"),
29     llvm::cl::init(10));
30 
31 namespace {
32 
33 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
34 /// applies the locally optimal patterns in a roughly "bottom up" way.
35 class GreedyPatternRewriteDriver : public PatternRewriter {
36 public:
37   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
38                                       const OwningRewritePatternList &patterns)
39       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
40     worklist.reserve(64);
41   }
42 
43   /// Perform the rewrites. Return true if the rewrite converges in
44   /// `maxIterations`.
45   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
46 
47   void addToWorklist(Operation *op) {
48     // Check to see if the worklist already contains this op.
49     if (worklistMap.count(op))
50       return;
51 
52     worklistMap[op] = worklist.size();
53     worklist.push_back(op);
54   }
55 
56   Operation *popFromWorklist() {
57     auto *op = worklist.back();
58     worklist.pop_back();
59 
60     // This operation is no longer in the worklist, keep worklistMap up to date.
61     if (op)
62       worklistMap.erase(op);
63     return op;
64   }
65 
66   /// If the specified operation is in the worklist, remove it.  If not, this is
67   /// a no-op.
68   void removeFromWorklist(Operation *op) {
69     auto it = worklistMap.find(op);
70     if (it != worklistMap.end()) {
71       assert(worklist[it->second] == op && "malformed worklist data structure");
72       worklist[it->second] = nullptr;
73       worklistMap.erase(it);
74     }
75   }
76 
77   // These are hooks implemented for PatternRewriter.
78 protected:
79   // Implement the hook for inserting operations, and make sure that newly
80   // inserted ops are added to the worklist for processing.
81   Operation *insert(Operation *op) override {
82     addToWorklist(op);
83     return OpBuilder::insert(op);
84   }
85 
86   // If an operation is about to be removed, make sure it is not in our
87   // worklist anymore because we'd get dangling references to it.
88   void notifyOperationRemoved(Operation *op) override {
89     addToWorklist(op->getOperands());
90     op->walk([this](Operation *operation) {
91       removeFromWorklist(operation);
92       folder.notifyRemoval(operation);
93     });
94   }
95 
96   // When the root of a pattern is about to be replaced, it can trigger
97   // simplifications to its users - make sure to add them to the worklist
98   // before the root is changed.
99   void notifyRootReplaced(Operation *op) override {
100     for (auto result : op->getResults())
101       for (auto *user : result.getUsers())
102         addToWorklist(user);
103   }
104 
105 private:
106   // Look over the provided operands for any defining operations that should
107   // be re-added to the worklist. This function should be called when an
108   // operation is modified or removed, as it may trigger further
109   // simplifications.
110   template <typename Operands> void addToWorklist(Operands &&operands) {
111     for (Value operand : operands) {
112       // If the use count of this operand is now < 2, we re-add the defining
113       // operation to the worklist.
114       // TODO(riverriddle) This is based on the fact that zero use operations
115       // may be deleted, and that single use values often have more
116       // canonicalization opportunities.
117       if (!operand.use_empty() && !operand.hasOneUse())
118         continue;
119       if (auto *defInst = operand.getDefiningOp())
120         addToWorklist(defInst);
121     }
122   }
123 
124   /// The low-level pattern matcher.
125   RewritePatternMatcher matcher;
126 
127   /// The worklist for this transformation keeps track of the operations that
128   /// need to be revisited, plus their index in the worklist.  This allows us to
129   /// efficiently remove operations from the worklist when they are erased, even
130   /// if they aren't the root of a pattern.
131   std::vector<Operation *> worklist;
132   DenseMap<Operation *, unsigned> worklistMap;
133 
134   /// Non-pattern based folder for operations.
135   OperationFolder folder;
136 };
137 } // end anonymous namespace
138 
139 /// Perform the rewrites.
140 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
141                                           int maxIterations) {
142   // Add the given operation to the worklist.
143   auto collectOps = [this](Operation *op) { addToWorklist(op); };
144 
145   bool changed = false;
146   int i = 0;
147   do {
148     // Add all nested operations to the worklist.
149     for (auto &region : regions)
150       region.walk(collectOps);
151 
152     // These are scratch vectors used in the folding loop below.
153     SmallVector<Value, 8> originalOperands, resultValues;
154 
155     changed = false;
156     while (!worklist.empty()) {
157       auto *op = popFromWorklist();
158 
159       // Nulls get added to the worklist when operations are removed, ignore
160       // them.
161       if (op == nullptr)
162         continue;
163 
164       // If the operation is trivially dead - remove it.
165       if (isOpTriviallyDead(op)) {
166         notifyOperationRemoved(op);
167         op->erase();
168         changed = true;
169         continue;
170       }
171 
172       // Collects all the operands and result uses of the given `op` into work
173       // list. Also remove `op` and nested ops from worklist.
174       originalOperands.assign(op->operand_begin(), op->operand_end());
175       auto preReplaceAction = [&](Operation *op) {
176         // Add the operands to the worklist for visitation.
177         addToWorklist(originalOperands);
178 
179         // Add all the users of the result to the worklist so we make sure
180         // to revisit them.
181         for (auto result : op->getResults())
182           for (auto *userOp : result.getUsers())
183             addToWorklist(userOp);
184 
185         notifyOperationRemoved(op);
186       };
187 
188       // Try to fold this op.
189       if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
190         changed = true;
191         continue;
192       }
193 
194       // Make sure that any new operations are inserted at this point.
195       setInsertionPoint(op);
196 
197       // Try to match one of the patterns. The rewriter is automatically
198       // notified of any necessary changes, so there is nothing else to do here.
199       changed |= matcher.matchAndRewrite(op, *this);
200     }
201 
202     // After applying patterns, make sure that the CFG of each of the regions is
203     // kept up to date.
204     if (succeeded(simplifyRegions(regions))) {
205       folder.clear();
206       changed = true;
207     }
208   } while (changed && ++i < maxIterations);
209   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
210   return !changed;
211 }
212 
213 /// Rewrite the regions of the specified operation, which must be isolated from
214 /// above, by repeatedly applying the highest benefit patterns in a greedy
215 /// work-list driven manner. Return true if no more patterns can be matched in
216 /// the result operation regions.
217 /// Note: This does not apply patterns to the top-level operation itself.
218 ///
219 bool mlir::applyPatternsGreedily(Operation *op,
220                                  const OwningRewritePatternList &patterns) {
221   return applyPatternsGreedily(op->getRegions(), patterns);
222 }
223 
224 /// Rewrite the given regions, which must be isolated from above.
225 bool mlir::applyPatternsGreedily(MutableArrayRef<Region> regions,
226                                  const OwningRewritePatternList &patterns) {
227   if (regions.empty())
228     return true;
229 
230   // The top-level operation must be known to be isolated from above to
231   // prevent performing canonicalizations on operations defined at or above
232   // the region containing 'op'.
233   auto regionIsIsolated = [](Region &region) {
234     return region.getParentOp()->isKnownIsolatedFromAbove();
235   };
236   (void)regionIsIsolated;
237   assert(llvm::all_of(regions, regionIsIsolated) &&
238          "patterns can only be applied to operations IsolatedFromAbove");
239 
240   // Start the pattern driver.
241   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
242   bool converged = driver.simplify(regions, maxPatternMatchIterations);
243   LLVM_DEBUG(if (!converged) {
244     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
245                  << maxPatternMatchIterations << " times";
246   });
247   return converged;
248 }
249