1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements mlir::applyPatternsAndFoldGreedily and
// mlir::applyOpPatternsAndFold.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/RegionUtils.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "pattern-matcher"
26 
27 //===----------------------------------------------------------------------===//
28 // GreedyPatternRewriteDriver
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
33 /// applies the locally optimal patterns in a roughly "bottom up" way.
34 class GreedyPatternRewriteDriver : public PatternRewriter {
35 public:
36   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
37                                       const FrozenRewritePatternSet &patterns,
38                                       const GreedyRewriteConfig &config)
39       : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
40     worklist.reserve(64);
41 
42     // Apply a simple cost model based solely on pattern benefit.
43     matcher.applyDefaultCostModel();
44   }
45 
46   bool simplify(MutableArrayRef<Region> regions);
47 
48   void addToWorklist(Operation *op) {
49     // Check to see if the worklist already contains this op.
50     if (worklistMap.count(op))
51       return;
52 
53     worklistMap[op] = worklist.size();
54     worklist.push_back(op);
55   }
56 
57   Operation *popFromWorklist() {
58     auto *op = worklist.back();
59     worklist.pop_back();
60 
61     // This operation is no longer in the worklist, keep worklistMap up to date.
62     if (op)
63       worklistMap.erase(op);
64     return op;
65   }
66 
67   /// If the specified operation is in the worklist, remove it.  If not, this is
68   /// a no-op.
69   void removeFromWorklist(Operation *op) {
70     auto it = worklistMap.find(op);
71     if (it != worklistMap.end()) {
72       assert(worklist[it->second] == op && "malformed worklist data structure");
73       worklist[it->second] = nullptr;
74       worklistMap.erase(it);
75     }
76   }
77 
78   // These are hooks implemented for PatternRewriter.
79 protected:
80   // Implement the hook for inserting operations, and make sure that newly
81   // inserted ops are added to the worklist for processing.
82   void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
83 
84   // Look over the provided operands for any defining operations that should
85   // be re-added to the worklist. This function should be called when an
86   // operation is modified or removed, as it may trigger further
87   // simplifications.
88   template <typename Operands>
89   void addToWorklist(Operands &&operands) {
90     for (Value operand : operands) {
91       // If the use count of this operand is now < 2, we re-add the defining
92       // operation to the worklist.
93       // TODO: This is based on the fact that zero use operations
94       // may be deleted, and that single use values often have more
95       // canonicalization opportunities.
96       if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
97         continue;
98       if (auto *defOp = operand.getDefiningOp())
99         addToWorklist(defOp);
100     }
101   }
102 
103   // If an operation is about to be removed, make sure it is not in our
104   // worklist anymore because we'd get dangling references to it.
105   void notifyOperationRemoved(Operation *op) override {
106     addToWorklist(op->getOperands());
107     op->walk([this](Operation *operation) {
108       removeFromWorklist(operation);
109       folder.notifyRemoval(operation);
110     });
111   }
112 
113   // When the root of a pattern is about to be replaced, it can trigger
114   // simplifications to its users - make sure to add them to the worklist
115   // before the root is changed.
116   void notifyRootReplaced(Operation *op) override {
117     for (auto result : op->getResults())
118       for (auto *user : result.getUsers())
119         addToWorklist(user);
120   }
121 
122   /// The low-level pattern applicator.
123   PatternApplicator matcher;
124 
125   /// The worklist for this transformation keeps track of the operations that
126   /// need to be revisited, plus their index in the worklist.  This allows us to
127   /// efficiently remove operations from the worklist when they are erased, even
128   /// if they aren't the root of a pattern.
129   std::vector<Operation *> worklist;
130   DenseMap<Operation *, unsigned> worklistMap;
131 
132   /// Non-pattern based folder for operations.
133   OperationFolder folder;
134 
135 private:
136   /// Configuration information for how to simplify.
137   GreedyRewriteConfig config;
138 };
139 } // end anonymous namespace
140 
141 /// Performs the rewrites while folding and erasing any dead ops. Returns true
142 /// if the rewrite converges in `maxIterations`.
143 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
144   bool changed = false;
145   unsigned iteration = 0;
146   do {
147     worklist.clear();
148     worklistMap.clear();
149 
150     if (!config.useTopDownTraversal) {
151       // Add operations to the worklist in postorder.
152       for (auto &region : regions)
153         region.walk([this](Operation *op) { addToWorklist(op); });
154     } else {
155       // Add all nested operations to the worklist in preorder.
156       for (auto &region : regions)
157         region.walk<WalkOrder::PreOrder>(
158             [this](Operation *op) { worklist.push_back(op); });
159 
160       // Reverse the list so our pop-back loop processes them in-order.
161       std::reverse(worklist.begin(), worklist.end());
162       // Remember the reverse index.
163       for (size_t i = 0, e = worklist.size(); i != e; ++i)
164         worklistMap[worklist[i]] = i;
165     }
166 
167     // These are scratch vectors used in the folding loop below.
168     SmallVector<Value, 8> originalOperands, resultValues;
169 
170     changed = false;
171     while (!worklist.empty()) {
172       auto *op = popFromWorklist();
173 
174       // Nulls get added to the worklist when operations are removed, ignore
175       // them.
176       if (op == nullptr)
177         continue;
178 
179       // If the operation is trivially dead - remove it.
180       if (isOpTriviallyDead(op)) {
181         notifyOperationRemoved(op);
182         op->erase();
183         changed = true;
184         continue;
185       }
186 
187       // Collects all the operands and result uses of the given `op` into work
188       // list. Also remove `op` and nested ops from worklist.
189       originalOperands.assign(op->operand_begin(), op->operand_end());
190       auto preReplaceAction = [&](Operation *op) {
191         // Add the operands to the worklist for visitation.
192         addToWorklist(originalOperands);
193 
194         // Add all the users of the result to the worklist so we make sure
195         // to revisit them.
196         for (auto result : op->getResults())
197           for (auto *userOp : result.getUsers())
198             addToWorklist(userOp);
199 
200         notifyOperationRemoved(op);
201       };
202 
203       // Add the given operation to the worklist.
204       auto collectOps = [this](Operation *op) { addToWorklist(op); };
205 
206       // Try to fold this op.
207       bool inPlaceUpdate;
208       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
209                                       &inPlaceUpdate)))) {
210         changed = true;
211         if (!inPlaceUpdate)
212           continue;
213       }
214 
215       // Try to match one of the patterns. The rewriter is automatically
216       // notified of any necessary changes, so there is nothing else to do
217       // here.
218       changed |= succeeded(matcher.matchAndRewrite(op, *this));
219     }
220 
221     // After applying patterns, make sure that the CFG of each of the regions
222     // is kept up to date.
223     if (config.enableRegionSimplification)
224       changed |= succeeded(simplifyRegions(*this, regions));
225   } while (changed &&
226            (++iteration < config.maxIterations ||
227             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
228 
229   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
230   return !changed;
231 }
232 
233 /// Rewrite the regions of the specified operation, which must be isolated from
234 /// above, by repeatedly applying the highest benefit patterns in a greedy
235 /// work-list driven manner. Return success if no more patterns can be matched
236 /// in the result operation regions. Note: This does not apply patterns to the
237 /// top-level operation itself.
238 ///
239 LogicalResult
240 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
241                                    const FrozenRewritePatternSet &patterns,
242                                    GreedyRewriteConfig config) {
243   if (regions.empty())
244     return success();
245 
246   // The top-level operation must be known to be isolated from above to
247   // prevent performing canonicalizations on operations defined at or above
248   // the region containing 'op'.
249   auto regionIsIsolated = [](Region &region) {
250     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
251   };
252   (void)regionIsIsolated;
253   assert(llvm::all_of(regions, regionIsIsolated) &&
254          "patterns can only be applied to operations IsolatedFromAbove");
255 
256   // Start the pattern driver.
257   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
258   bool converged = driver.simplify(regions);
259   LLVM_DEBUG(if (!converged) {
260     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
261                  << config.maxIterations << " times\n";
262   });
263   return success(converged);
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // OpPatternRewriteDriver
268 //===----------------------------------------------------------------------===//
269 
270 namespace {
271 /// This is a simple driver for the PatternMatcher to apply patterns and perform
272 /// folding on a single op. It repeatedly applies locally optimal patterns.
273 class OpPatternRewriteDriver : public PatternRewriter {
274 public:
275   explicit OpPatternRewriteDriver(MLIRContext *ctx,
276                                   const FrozenRewritePatternSet &patterns)
277       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
278     // Apply a simple cost model based solely on pattern benefit.
279     matcher.applyDefaultCostModel();
280   }
281 
282   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
283 
284   // These are hooks implemented for PatternRewriter.
285 protected:
286   /// If an operation is about to be removed, mark it so that we can let clients
287   /// know.
288   void notifyOperationRemoved(Operation *op) override {
289     opErasedViaPatternRewrites = true;
290   }
291 
292   // When a root is going to be replaced, its removal will be notified as well.
293   // So there is nothing to do here.
294   void notifyRootReplaced(Operation *op) override {}
295 
296 private:
297   /// The low-level pattern applicator.
298   PatternApplicator matcher;
299 
300   /// Non-pattern based folder for operations.
301   OperationFolder folder;
302 
303   /// Set to true if the operation has been erased via pattern rewrites.
304   bool opErasedViaPatternRewrites = false;
305 };
306 
307 } // anonymous namespace
308 
309 /// Performs the rewrites and folding only on `op`. The simplification
310 /// converges if the op is erased as a result of being folded, replaced, or
311 /// becoming dead, or no more changes happen in an iteration. Returns success if
312 /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
313 /// gets erased.
314 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
315                                                       int maxIterations,
316                                                       bool &erased) {
317   bool changed = false;
318   erased = false;
319   opErasedViaPatternRewrites = false;
320   int iterations = 0;
321   // Iterate until convergence or until maxIterations. Deletion of the op as
322   // a result of being dead or folded is convergence.
323   do {
324     changed = false;
325 
326     // If the operation is trivially dead - remove it.
327     if (isOpTriviallyDead(op)) {
328       op->erase();
329       erased = true;
330       return success();
331     }
332 
333     // Try to fold this op.
334     bool inPlaceUpdate;
335     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
336                                    /*preReplaceAction=*/nullptr,
337                                    &inPlaceUpdate))) {
338       changed = true;
339       if (!inPlaceUpdate) {
340         erased = true;
341         return success();
342       }
343     }
344 
345     // Try to match one of the patterns. The rewriter is automatically
346     // notified of any necessary changes, so there is nothing else to do here.
347     changed |= succeeded(matcher.matchAndRewrite(op, *this));
348     if ((erased = opErasedViaPatternRewrites))
349       return success();
350   } while (changed &&
351            (++iterations < maxIterations ||
352             maxIterations == GreedyRewriteConfig::kNoIterationLimit));
353 
354   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
355   return failure(changed);
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // MultiOpPatternRewriteDriver
360 //===----------------------------------------------------------------------===//
361 
362 namespace {
363 
364 /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
365 /// perform folding for a supplied set of ops. It repeatedly simplifies while
366 /// restricting the rewrites to only the provided set of ops or optionally
367 /// to those directly affected by it (result users or operand providers).
368 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
369 public:
370   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
371                                        const FrozenRewritePatternSet &patterns,
372                                        bool strict)
373       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
374         strictMode(strict) {}
375 
376   bool simplifyLocally(ArrayRef<Operation *> op);
377 
378 private:
379   // Look over the provided operands for any defining operations that should
380   // be re-added to the worklist. This function should be called when an
381   // operation is modified or removed, as it may trigger further
382   // simplifications. If `strict` is set to true, only ops in
383   // `strictModeFilteredOps` are considered.
384   template <typename Operands>
385   void addOperandsToWorklist(Operands &&operands) {
386     for (Value operand : operands) {
387       if (auto *defOp = operand.getDefiningOp()) {
388         if (!strictMode || strictModeFilteredOps.contains(defOp))
389           addToWorklist(defOp);
390       }
391     }
392   }
393 
394   void notifyOperationRemoved(Operation *op) override {
395     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
396     if (strictMode)
397       strictModeFilteredOps.erase(op);
398   }
399 
400   /// If `strictMode` is true, any pre-existing ops outside of
401   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
402   /// If `strictMode` is false, operations that use results of (or supply
403   /// operands to) any rewritten ops stemming from the simplification of the
404   /// provided ops are in turn simplified; any other ops still remain untouched
405   /// (i.e., regardless of `strictMode`).
406   bool strictMode = false;
407 
408   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
409   /// These include the supplied set of ops as well as new ops created while
410   /// rewriting those ops. This set is not maintained when strictMode is off.
411   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
412 };
413 
414 } // end anonymous namespace
415 
416 /// Performs the specified rewrites on `ops` while also trying to fold these ops
417 /// as well as any other ops that were in turn created due to these rewrite
418 /// patterns. Any pre-existing ops outside of `ops` remain completely
419 /// unmodified if `strictMode` is true. If `strictMode` is false, other
420 /// operations that use results of rewritten ops or supply operands to such ops
421 /// are in turn simplified; any other ops still remain unmodified (i.e.,
422 /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
423 /// result of folding, becoming dead, or via pattern rewrites. Returns true if
424 /// at all any changes happened.
425 // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
426 // or GreedyPatternRewriteDriver::simplify, this method just iterates until
427 // the worklist is empty. As our objective is to keep simplification "local",
428 // there is no strong rationale to re-add all operations into the worklist and
429 // rerun until an iteration changes nothing. If more widereaching simplification
430 // is desired, GreedyPatternRewriteDriver should be used.
431 bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
432   if (strictMode) {
433     strictModeFilteredOps.clear();
434     strictModeFilteredOps.insert(ops.begin(), ops.end());
435   }
436 
437   bool changed = false;
438   worklist.clear();
439   worklistMap.clear();
440   for (Operation *op : ops)
441     addToWorklist(op);
442 
443   // These are scratch vectors used in the folding loop below.
444   SmallVector<Value, 8> originalOperands, resultValues;
445   while (!worklist.empty()) {
446     Operation *op = popFromWorklist();
447 
448     // Nulls get added to the worklist when operations are removed, ignore
449     // them.
450     if (op == nullptr)
451       continue;
452 
453     // If the operation is trivially dead - remove it.
454     if (isOpTriviallyDead(op)) {
455       notifyOperationRemoved(op);
456       op->erase();
457       changed = true;
458       continue;
459     }
460 
461     // Collects all the operands and result uses of the given `op` into work
462     // list. Also remove `op` and nested ops from worklist.
463     originalOperands.assign(op->operand_begin(), op->operand_end());
464     auto preReplaceAction = [&](Operation *op) {
465       // Add the operands to the worklist for visitation.
466       addOperandsToWorklist(originalOperands);
467 
468       // Add all the users of the result to the worklist so we make sure
469       // to revisit them.
470       for (Value result : op->getResults())
471         for (Operation *userOp : result.getUsers()) {
472           if (!strictMode || strictModeFilteredOps.contains(userOp))
473             addToWorklist(userOp);
474         }
475       notifyOperationRemoved(op);
476     };
477 
478     // Add the given operation generated by the folder to the worklist.
479     auto processGeneratedConstants = [this](Operation *op) {
480       // Newly created ops are also simplified -- these are also "local".
481       addToWorklist(op);
482       // When strict mode is off, we don't need to maintain
483       // strictModeFilteredOps.
484       if (strictMode)
485         strictModeFilteredOps.insert(op);
486     };
487 
488     // Try to fold this op.
489     bool inPlaceUpdate;
490     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
491                                    preReplaceAction, &inPlaceUpdate))) {
492       changed = true;
493       if (!inPlaceUpdate) {
494         // Op has been erased.
495         continue;
496       }
497     }
498 
499     // Try to match one of the patterns. The rewriter is automatically
500     // notified of any necessary changes, so there is nothing else to do
501     // here.
502     changed |= succeeded(matcher.matchAndRewrite(op, *this));
503   }
504 
505   return changed;
506 }
507 
508 /// Rewrites only `op` using the supplied canonicalization patterns and
509 /// folding. `erased` is set to true if the op is erased as a result of being
510 /// folded, replaced, or dead.
511 LogicalResult mlir::applyOpPatternsAndFold(
512     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
513   // Start the pattern driver.
514   GreedyRewriteConfig config;
515   OpPatternRewriteDriver driver(op->getContext(), patterns);
516   bool opErased;
517   LogicalResult converged =
518       driver.simplifyLocally(op, config.maxIterations, opErased);
519   if (erased)
520     *erased = opErased;
521   LLVM_DEBUG(if (failed(converged)) {
522     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
523                  << config.maxIterations << " times";
524   });
525   return converged;
526 }
527 
528 bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
529                                   const FrozenRewritePatternSet &patterns,
530                                   bool strict) {
531   if (ops.empty())
532     return false;
533 
534   // Start the pattern driver.
535   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
536                                      strict);
537   return driver.simplifyLocally(ops);
538 }
539