1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "pattern-matcher"
26 
/// Default upper bound on the number of driver iterations, used by the entry
/// points in this file that do not take an explicit iteration count. It is
/// never modified, so it is a compile-time constant rather than a mutable
/// global.
static constexpr unsigned maxPatternMatchIterations = 10;
29 
30 //===----------------------------------------------------------------------===//
31 // GreedyPatternRewriteDriver
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
36 /// applies the locally optimal patterns in a roughly "bottom up" way.
37 class GreedyPatternRewriteDriver : public PatternRewriter {
38 public:
39   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
40                                       const FrozenRewritePatternSet &patterns,
41                                       bool useTopDownTraversal)
42       : PatternRewriter(ctx), matcher(patterns), folder(ctx),
43         useTopDownTraversal(useTopDownTraversal) {
44     worklist.reserve(64);
45 
46     // Apply a simple cost model based solely on pattern benefit.
47     matcher.applyDefaultCostModel();
48   }
49 
50   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
51 
52   void addToWorklist(Operation *op) {
53     // Check to see if the worklist already contains this op.
54     if (worklistMap.count(op))
55       return;
56 
57     worklistMap[op] = worklist.size();
58     worklist.push_back(op);
59   }
60 
61   Operation *popFromWorklist() {
62     auto *op = worklist.back();
63     worklist.pop_back();
64 
65     // This operation is no longer in the worklist, keep worklistMap up to date.
66     if (op)
67       worklistMap.erase(op);
68     return op;
69   }
70 
71   /// If the specified operation is in the worklist, remove it.  If not, this is
72   /// a no-op.
73   void removeFromWorklist(Operation *op) {
74     auto it = worklistMap.find(op);
75     if (it != worklistMap.end()) {
76       assert(worklist[it->second] == op && "malformed worklist data structure");
77       worklist[it->second] = nullptr;
78       worklistMap.erase(it);
79     }
80   }
81 
82   // These are hooks implemented for PatternRewriter.
83 protected:
84   // Implement the hook for inserting operations, and make sure that newly
85   // inserted ops are added to the worklist for processing.
86   void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
87 
88   // If an operation is about to be removed, make sure it is not in our
89   // worklist anymore because we'd get dangling references to it.
90   void notifyOperationRemoved(Operation *op) override {
91     addToWorklist(op->getOperands());
92     op->walk([this](Operation *operation) {
93       removeFromWorklist(operation);
94       folder.notifyRemoval(operation);
95     });
96   }
97 
98   // When the root of a pattern is about to be replaced, it can trigger
99   // simplifications to its users - make sure to add them to the worklist
100   // before the root is changed.
101   void notifyRootReplaced(Operation *op) override {
102     for (auto result : op->getResults())
103       for (auto *user : result.getUsers())
104         addToWorklist(user);
105   }
106 
107 private:
108   // Look over the provided operands for any defining operations that should
109   // be re-added to the worklist. This function should be called when an
110   // operation is modified or removed, as it may trigger further
111   // simplifications.
112   template <typename Operands>
113   void addToWorklist(Operands &&operands) {
114     for (Value operand : operands) {
115       // If the use count of this operand is now < 2, we re-add the defining
116       // operation to the worklist.
117       // TODO: This is based on the fact that zero use operations
118       // may be deleted, and that single use values often have more
119       // canonicalization opportunities.
120       if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
121         continue;
122       if (auto *defInst = operand.getDefiningOp())
123         addToWorklist(defInst);
124     }
125   }
126 
127   /// The low-level pattern applicator.
128   PatternApplicator matcher;
129 
130   /// The worklist for this transformation keeps track of the operations that
131   /// need to be revisited, plus their index in the worklist.  This allows us to
132   /// efficiently remove operations from the worklist when they are erased, even
133   /// if they aren't the root of a pattern.
134   std::vector<Operation *> worklist;
135   DenseMap<Operation *, unsigned> worklistMap;
136 
137   /// Non-pattern based folder for operations.
138   OperationFolder folder;
139 
140   // Whether to use top-down or bottom-up traversal order.
141   bool useTopDownTraversal;
142 };
143 } // end anonymous namespace
144 
145 /// Performs the rewrites while folding and erasing any dead ops. Returns true
146 /// if the rewrite converges in `maxIterations`.
147 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
148                                           int maxIterations) {
149   // Perform a prepass over the IR to discover constants.
150   for (auto &region : regions)
151     folder.processExistingConstants(region);
152 
153   bool changed = false;
154   int iteration = 0;
155   do {
156     worklist.clear();
157     worklistMap.clear();
158 
159     // Add all nested operations to the worklist in preorder.
160     for (auto &region : regions)
161       if (useTopDownTraversal)
162         region.walk<WalkOrder::PreOrder>(
163             [this](Operation *op) { worklist.push_back(op); });
164       else
165         region.walk([this](Operation *op) { addToWorklist(op); });
166 
167     if (useTopDownTraversal) {
168       // Reverse the list so our pop-back loop processes them in-order.
169       std::reverse(worklist.begin(), worklist.end());
170       // Remember the reverse index.
171       for (unsigned i = 0, e = worklist.size(); i != e; ++i)
172         worklistMap[worklist[i]] = i;
173     }
174 
175     // These are scratch vectors used in the folding loop below.
176     SmallVector<Value, 8> originalOperands, resultValues;
177 
178     changed = false;
179     while (!worklist.empty()) {
180       auto *op = popFromWorklist();
181 
182       // Nulls get added to the worklist when operations are removed, ignore
183       // them.
184       if (op == nullptr)
185         continue;
186 
187       // If the operation is trivially dead - remove it.
188       if (isOpTriviallyDead(op)) {
189         notifyOperationRemoved(op);
190         op->erase();
191         changed = true;
192         continue;
193       }
194 
195       // Collects all the operands and result uses of the given `op` into work
196       // list. Also remove `op` and nested ops from worklist.
197       originalOperands.assign(op->operand_begin(), op->operand_end());
198       auto preReplaceAction = [&](Operation *op) {
199         // Add the operands to the worklist for visitation.
200         addToWorklist(originalOperands);
201 
202         // Add all the users of the result to the worklist so we make sure
203         // to revisit them.
204         for (auto result : op->getResults())
205           for (auto *userOp : result.getUsers())
206             addToWorklist(userOp);
207 
208         notifyOperationRemoved(op);
209       };
210 
211       // Add the given operation to the worklist.
212       auto collectOps = [this](Operation *op) { addToWorklist(op); };
213 
214       // Try to fold this op.
215       bool inPlaceUpdate;
216       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
217                                       &inPlaceUpdate)))) {
218         changed = true;
219         if (!inPlaceUpdate)
220           continue;
221       }
222 
223       // Try to match one of the patterns. The rewriter is automatically
224       // notified of any necessary changes, so there is nothing else to do here.
225       changed |= succeeded(matcher.matchAndRewrite(op, *this));
226     }
227 
228     // After applying patterns, make sure that the CFG of each of the regions is
229     // kept up to date.
230     changed |= succeeded(simplifyRegions(*this, regions));
231   } while (changed && ++iteration < maxIterations);
232 
233   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
234   return !changed;
235 }
236 
237 /// Rewrite the regions of the specified operation, which must be isolated from
238 /// above, by repeatedly applying the highest benefit patterns in a greedy
239 /// work-list driven manner. Return success if no more patterns can be matched
240 /// in the result operation regions. Note: This does not apply patterns to the
241 /// top-level operation itself.
242 ///
243 LogicalResult
244 mlir::applyPatternsAndFoldGreedily(Operation *op,
245                                    const FrozenRewritePatternSet &patterns,
246                                    bool useTopDownTraversal) {
247   return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
248                                       useTopDownTraversal);
249 }
250 LogicalResult mlir::applyPatternsAndFoldGreedily(
251     Operation *op, const FrozenRewritePatternSet &patterns,
252     unsigned maxIterations, bool useTopDownTraversal) {
253   return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
254                                       useTopDownTraversal);
255 }
256 /// Rewrite the given regions, which must be isolated from above.
257 LogicalResult
258 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
259                                    const FrozenRewritePatternSet &patterns,
260                                    bool useTopDownTraversal) {
261   return applyPatternsAndFoldGreedily(
262       regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
263 }
264 LogicalResult mlir::applyPatternsAndFoldGreedily(
265     MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
266     unsigned maxIterations, bool useTopDownTraversal) {
267   if (regions.empty())
268     return success();
269 
270   // The top-level operation must be known to be isolated from above to
271   // prevent performing canonicalizations on operations defined at or above
272   // the region containing 'op'.
273   auto regionIsIsolated = [](Region &region) {
274     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
275   };
276   (void)regionIsIsolated;
277   assert(llvm::all_of(regions, regionIsIsolated) &&
278          "patterns can only be applied to operations IsolatedFromAbove");
279 
280   // Start the pattern driver.
281   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
282                                     useTopDownTraversal);
283   bool converged = driver.simplify(regions, maxIterations);
284   LLVM_DEBUG(if (!converged) {
285     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
286                  << maxIterations << " times\n";
287   });
288   return success(converged);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // OpPatternRewriteDriver
293 //===----------------------------------------------------------------------===//
294 
295 namespace {
296 /// This is a simple driver for the PatternMatcher to apply patterns and perform
297 /// folding on a single op. It repeatedly applies locally optimal patterns.
298 class OpPatternRewriteDriver : public PatternRewriter {
299 public:
300   explicit OpPatternRewriteDriver(MLIRContext *ctx,
301                                   const FrozenRewritePatternSet &patterns)
302       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
303     // Apply a simple cost model based solely on pattern benefit.
304     matcher.applyDefaultCostModel();
305   }
306 
307   /// Performs the rewrites and folding only on `op`. The simplification
308   /// converges if the op is erased as a result of being folded, replaced, or
309   /// dead, or no more changes happen in an iteration. Returns success if the
310   /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets
311   /// erased.
312   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
313 
314   // These are hooks implemented for PatternRewriter.
315 protected:
316   /// If an operation is about to be removed, mark it so that we can let clients
317   /// know.
318   void notifyOperationRemoved(Operation *op) override {
319     opErasedViaPatternRewrites = true;
320   }
321 
322   // When a root is going to be replaced, its removal will be notified as well.
323   // So there is nothing to do here.
324   void notifyRootReplaced(Operation *op) override {}
325 
326 private:
327   /// The low-level pattern applicator.
328   PatternApplicator matcher;
329 
330   /// Non-pattern based folder for operations.
331   OperationFolder folder;
332 
333   /// Set to true if the operation has been erased via pattern rewrites.
334   bool opErasedViaPatternRewrites = false;
335 };
336 
337 } // anonymous namespace
338 
339 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
340                                                       int maxIterations,
341                                                       bool &erased) {
342   bool changed = false;
343   erased = false;
344   opErasedViaPatternRewrites = false;
345   int i = 0;
346   // Iterate until convergence or until maxIterations. Deletion of the op as
347   // a result of being dead or folded is convergence.
348   do {
349     changed = false;
350 
351     // If the operation is trivially dead - remove it.
352     if (isOpTriviallyDead(op)) {
353       op->erase();
354       erased = true;
355       return success();
356     }
357 
358     // Try to fold this op.
359     bool inPlaceUpdate;
360     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
361                                    /*preReplaceAction=*/nullptr,
362                                    &inPlaceUpdate))) {
363       changed = true;
364       if (!inPlaceUpdate) {
365         erased = true;
366         return success();
367       }
368     }
369 
370     // Try to match one of the patterns. The rewriter is automatically
371     // notified of any necessary changes, so there is nothing else to do here.
372     changed |= succeeded(matcher.matchAndRewrite(op, *this));
373     if ((erased = opErasedViaPatternRewrites))
374       return success();
375   } while (changed && ++i < maxIterations);
376 
377   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
378   return failure(changed);
379 }
380 
381 /// Rewrites only `op` using the supplied canonicalization patterns and
382 /// folding. `erased` is set to true if the op is erased as a result of being
383 /// folded, replaced, or dead.
384 LogicalResult mlir::applyOpPatternsAndFold(
385     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
386   // Start the pattern driver.
387   OpPatternRewriteDriver driver(op->getContext(), patterns);
388   bool opErased;
389   LogicalResult converged =
390       driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
391   if (erased)
392     *erased = opErased;
393   LLVM_DEBUG(if (failed(converged)) {
394     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
395                  << maxPatternMatchIterations << " times";
396   });
397   return converged;
398 }
399