1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements mlir::applyPatternsAndFoldGreedily and
// mlir::applyOpPatternsAndFold.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/RegionUtils.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "pattern-matcher"
25 
/// Upper bound on the number of driver iterations attempted before a rewrite
/// is reported as not converging. Both drivers in this file pass it as their
/// `maxIterations` argument. Nothing in this file writes to it, so it is a
/// compile-time constant rather than a mutable global.
static constexpr unsigned maxPatternMatchIterations = 10;
28 
29 //===----------------------------------------------------------------------===//
30 // GreedyPatternRewriteDriver
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
35 /// applies the locally optimal patterns in a roughly "bottom up" way.
36 class GreedyPatternRewriteDriver : public PatternRewriter {
37 public:
38   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
39                                       const OwningRewritePatternList &patterns)
40       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
41     worklist.reserve(64);
42 
43     // Apply a simple cost model based solely on pattern benefit.
44     matcher.applyDefaultCostModel();
45   }
46 
47   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
48 
49   void addToWorklist(Operation *op) {
50     // Check to see if the worklist already contains this op.
51     if (worklistMap.count(op))
52       return;
53 
54     worklistMap[op] = worklist.size();
55     worklist.push_back(op);
56   }
57 
58   Operation *popFromWorklist() {
59     auto *op = worklist.back();
60     worklist.pop_back();
61 
62     // This operation is no longer in the worklist, keep worklistMap up to date.
63     if (op)
64       worklistMap.erase(op);
65     return op;
66   }
67 
68   /// If the specified operation is in the worklist, remove it.  If not, this is
69   /// a no-op.
70   void removeFromWorklist(Operation *op) {
71     auto it = worklistMap.find(op);
72     if (it != worklistMap.end()) {
73       assert(worklist[it->second] == op && "malformed worklist data structure");
74       worklist[it->second] = nullptr;
75       worklistMap.erase(it);
76     }
77   }
78 
79   // These are hooks implemented for PatternRewriter.
80 protected:
81   // Implement the hook for inserting operations, and make sure that newly
82   // inserted ops are added to the worklist for processing.
83   void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
84 
85   // If an operation is about to be removed, make sure it is not in our
86   // worklist anymore because we'd get dangling references to it.
87   void notifyOperationRemoved(Operation *op) override {
88     addToWorklist(op->getOperands());
89     op->walk([this](Operation *operation) {
90       removeFromWorklist(operation);
91       folder.notifyRemoval(operation);
92     });
93   }
94 
95   // When the root of a pattern is about to be replaced, it can trigger
96   // simplifications to its users - make sure to add them to the worklist
97   // before the root is changed.
98   void notifyRootReplaced(Operation *op) override {
99     for (auto result : op->getResults())
100       for (auto *user : result.getUsers())
101         addToWorklist(user);
102   }
103 
104 private:
105   // Look over the provided operands for any defining operations that should
106   // be re-added to the worklist. This function should be called when an
107   // operation is modified or removed, as it may trigger further
108   // simplifications.
109   template <typename Operands> void addToWorklist(Operands &&operands) {
110     for (Value operand : operands) {
111       // If the use count of this operand is now < 2, we re-add the defining
112       // operation to the worklist.
113       // TODO: This is based on the fact that zero use operations
114       // may be deleted, and that single use values often have more
115       // canonicalization opportunities.
116       if (!operand.use_empty() && !operand.hasOneUse())
117         continue;
118       if (auto *defInst = operand.getDefiningOp())
119         addToWorklist(defInst);
120     }
121   }
122 
123   /// The low-level pattern applicator.
124   PatternApplicator matcher;
125 
126   /// The worklist for this transformation keeps track of the operations that
127   /// need to be revisited, plus their index in the worklist.  This allows us to
128   /// efficiently remove operations from the worklist when they are erased, even
129   /// if they aren't the root of a pattern.
130   std::vector<Operation *> worklist;
131   DenseMap<Operation *, unsigned> worklistMap;
132 
133   /// Non-pattern based folder for operations.
134   OperationFolder folder;
135 };
136 } // end anonymous namespace
137 
138 /// Performs the rewrites while folding and erasing any dead ops. Returns true
139 /// if the rewrite converges in `maxIterations`.
140 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
141                                           int maxIterations) {
142   // Add the given operation to the worklist.
143   auto collectOps = [this](Operation *op) { addToWorklist(op); };
144 
145   bool changed = false;
146   int i = 0;
147   do {
148     // Add all nested operations to the worklist.
149     for (auto &region : regions)
150       region.walk(collectOps);
151 
152     // These are scratch vectors used in the folding loop below.
153     SmallVector<Value, 8> originalOperands, resultValues;
154 
155     changed = false;
156     while (!worklist.empty()) {
157       auto *op = popFromWorklist();
158 
159       // Nulls get added to the worklist when operations are removed, ignore
160       // them.
161       if (op == nullptr)
162         continue;
163 
164       // If the operation is trivially dead - remove it.
165       if (isOpTriviallyDead(op)) {
166         notifyOperationRemoved(op);
167         op->erase();
168         changed = true;
169         continue;
170       }
171 
172       // Collects all the operands and result uses of the given `op` into work
173       // list. Also remove `op` and nested ops from worklist.
174       originalOperands.assign(op->operand_begin(), op->operand_end());
175       auto preReplaceAction = [&](Operation *op) {
176         // Add the operands to the worklist for visitation.
177         addToWorklist(originalOperands);
178 
179         // Add all the users of the result to the worklist so we make sure
180         // to revisit them.
181         for (auto result : op->getResults())
182           for (auto *userOp : result.getUsers())
183             addToWorklist(userOp);
184 
185         notifyOperationRemoved(op);
186       };
187 
188       // Try to fold this op.
189       bool inPlaceUpdate;
190       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
191                                       &inPlaceUpdate)))) {
192         changed = true;
193         if (!inPlaceUpdate)
194           continue;
195       }
196 
197       // Try to match one of the patterns. The rewriter is automatically
198       // notified of any necessary changes, so there is nothing else to do here.
199       changed |= succeeded(matcher.matchAndRewrite(op, *this));
200     }
201 
202     // After applying patterns, make sure that the CFG of each of the regions is
203     // kept up to date.
204     if (succeeded(simplifyRegions(regions))) {
205       folder.clear();
206       changed = true;
207     }
208   } while (changed && ++i < maxIterations);
209   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
210   return !changed;
211 }
212 
213 /// Rewrite the regions of the specified operation, which must be isolated from
214 /// above, by repeatedly applying the highest benefit patterns in a greedy
215 /// work-list driven manner. Return success if no more patterns can be matched
216 /// in the result operation regions. Note: This does not apply patterns to the
217 /// top-level operation itself.
218 ///
219 LogicalResult
220 mlir::applyPatternsAndFoldGreedily(Operation *op,
221                                    const OwningRewritePatternList &patterns) {
222   return applyPatternsAndFoldGreedily(op->getRegions(), patterns);
223 }
224 /// Rewrite the given regions, which must be isolated from above.
225 LogicalResult
226 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
227                                    const OwningRewritePatternList &patterns) {
228   if (regions.empty())
229     return success();
230 
231   // The top-level operation must be known to be isolated from above to
232   // prevent performing canonicalizations on operations defined at or above
233   // the region containing 'op'.
234   auto regionIsIsolated = [](Region &region) {
235     return region.getParentOp()->isKnownIsolatedFromAbove();
236   };
237   (void)regionIsIsolated;
238   assert(llvm::all_of(regions, regionIsIsolated) &&
239          "patterns can only be applied to operations IsolatedFromAbove");
240 
241   // Start the pattern driver.
242   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
243   bool converged = driver.simplify(regions, maxPatternMatchIterations);
244   LLVM_DEBUG(if (!converged) {
245     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
246                  << maxPatternMatchIterations << " times";
247   });
248   return success(converged);
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // OpPatternRewriteDriver
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 /// This is a simple driver for the PatternMatcher to apply patterns and perform
257 /// folding on a single op. It repeatedly applies locally optimal patterns.
258 class OpPatternRewriteDriver : public PatternRewriter {
259 public:
260   explicit OpPatternRewriteDriver(MLIRContext *ctx,
261                                   const OwningRewritePatternList &patterns)
262       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
263     // Apply a simple cost model based solely on pattern benefit.
264     matcher.applyDefaultCostModel();
265   }
266 
267   /// Performs the rewrites and folding only on `op`. The simplification
268   /// converges if the op is erased as a result of being folded, replaced, or
269   /// dead, or no more changes happen in an iteration. Returns success if the
270   /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets
271   /// erased.
272   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
273 
274   // These are hooks implemented for PatternRewriter.
275 protected:
276   /// If an operation is about to be removed, mark it so that we can let clients
277   /// know.
278   void notifyOperationRemoved(Operation *op) override {
279     opErasedViaPatternRewrites = true;
280   }
281 
282   // When a root is going to be replaced, its removal will be notified as well.
283   // So there is nothing to do here.
284   void notifyRootReplaced(Operation *op) override {}
285 
286 private:
287   /// The low-level pattern applicator.
288   PatternApplicator matcher;
289 
290   /// Non-pattern based folder for operations.
291   OperationFolder folder;
292 
293   /// Set to true if the operation has been erased via pattern rewrites.
294   bool opErasedViaPatternRewrites = false;
295 };
296 
297 } // anonymous namespace
298 
299 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
300                                                       int maxIterations,
301                                                       bool &erased) {
302   bool changed = false;
303   erased = false;
304   opErasedViaPatternRewrites = false;
305   int i = 0;
306   // Iterate until convergence or until maxIterations. Deletion of the op as
307   // a result of being dead or folded is convergence.
308   do {
309     // If the operation is trivially dead - remove it.
310     if (isOpTriviallyDead(op)) {
311       op->erase();
312       erased = true;
313       return success();
314     }
315 
316     // Try to fold this op.
317     bool inPlaceUpdate;
318     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
319                                    /*preReplaceAction=*/nullptr,
320                                    &inPlaceUpdate))) {
321       changed = true;
322       if (!inPlaceUpdate) {
323         erased = true;
324         return success();
325       }
326     }
327 
328     // Try to match one of the patterns. The rewriter is automatically
329     // notified of any necessary changes, so there is nothing else to do here.
330     changed |= succeeded(matcher.matchAndRewrite(op, *this));
331     if ((erased = opErasedViaPatternRewrites))
332       return success();
333   } while (changed && ++i < maxIterations);
334 
335   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
336   return failure(changed);
337 }
338 
339 /// Rewrites only `op` using the supplied canonicalization patterns and
340 /// folding. `erased` is set to true if the op is erased as a result of being
341 /// folded, replaced, or dead.
342 LogicalResult mlir::applyOpPatternsAndFold(
343     Operation *op, const OwningRewritePatternList &patterns, bool *erased) {
344   // Start the pattern driver.
345   OpPatternRewriteDriver driver(op->getContext(), patterns);
346   bool opErased;
347   LogicalResult converged =
348       driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
349   if (erased)
350     *erased = opErased;
351   LLVM_DEBUG(if (failed(converged)) {
352     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
353                  << maxPatternMatchIterations << " times";
354   });
355   return converged;
356 }
357