1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/RegionUtils.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "pattern-matcher"
26 
27 //===----------------------------------------------------------------------===//
28 // GreedyPatternRewriteDriver
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
33 /// applies the locally optimal patterns in a roughly "bottom up" way.
34 class GreedyPatternRewriteDriver : public PatternRewriter {
35 public:
36   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
37                                       const FrozenRewritePatternSet &patterns,
38                                       const GreedyRewriteConfig &config)
39       : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
40     worklist.reserve(64);
41 
42     // Apply a simple cost model based solely on pattern benefit.
43     matcher.applyDefaultCostModel();
44   }
45 
46   bool simplify(MutableArrayRef<Region> regions);
47 
48   void addToWorklist(Operation *op) {
49     // Check to see if the worklist already contains this op.
50     if (worklistMap.count(op))
51       return;
52 
53     worklistMap[op] = worklist.size();
54     worklist.push_back(op);
55   }
56 
57   Operation *popFromWorklist() {
58     auto *op = worklist.back();
59     worklist.pop_back();
60 
61     // This operation is no longer in the worklist, keep worklistMap up to date.
62     if (op)
63       worklistMap.erase(op);
64     return op;
65   }
66 
67   /// If the specified operation is in the worklist, remove it.  If not, this is
68   /// a no-op.
69   void removeFromWorklist(Operation *op) {
70     auto it = worklistMap.find(op);
71     if (it != worklistMap.end()) {
72       assert(worklist[it->second] == op && "malformed worklist data structure");
73       worklist[it->second] = nullptr;
74       worklistMap.erase(it);
75     }
76   }
77 
78   // These are hooks implemented for PatternRewriter.
79 protected:
80   // Implement the hook for inserting operations, and make sure that newly
81   // inserted ops are added to the worklist for processing.
82   void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
83 
84   // Look over the provided operands for any defining operations that should
85   // be re-added to the worklist. This function should be called when an
86   // operation is modified or removed, as it may trigger further
87   // simplifications.
88   template <typename Operands>
89   void addToWorklist(Operands &&operands) {
90     for (Value operand : operands) {
91       // If the use count of this operand is now < 2, we re-add the defining
92       // operation to the worklist.
93       // TODO: This is based on the fact that zero use operations
94       // may be deleted, and that single use values often have more
95       // canonicalization opportunities.
96       if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
97         continue;
98       if (auto *defOp = operand.getDefiningOp())
99         addToWorklist(defOp);
100     }
101   }
102 
103   // If an operation is about to be removed, make sure it is not in our
104   // worklist anymore because we'd get dangling references to it.
105   void notifyOperationRemoved(Operation *op) override {
106     addToWorklist(op->getOperands());
107     op->walk([this](Operation *operation) {
108       removeFromWorklist(operation);
109       folder.notifyRemoval(operation);
110     });
111   }
112 
113   // When the root of a pattern is about to be replaced, it can trigger
114   // simplifications to its users - make sure to add them to the worklist
115   // before the root is changed.
116   void notifyRootReplaced(Operation *op) override {
117     for (auto result : op->getResults())
118       for (auto *user : result.getUsers())
119         addToWorklist(user);
120   }
121 
122   /// The low-level pattern applicator.
123   PatternApplicator matcher;
124 
125   /// The worklist for this transformation keeps track of the operations that
126   /// need to be revisited, plus their index in the worklist.  This allows us to
127   /// efficiently remove operations from the worklist when they are erased, even
128   /// if they aren't the root of a pattern.
129   std::vector<Operation *> worklist;
130   DenseMap<Operation *, unsigned> worklistMap;
131 
132   /// Non-pattern based folder for operations.
133   OperationFolder folder;
134 
135 private:
136   /// Configuration information for how to simplify.
137   GreedyRewriteConfig config;
138 };
139 } // end anonymous namespace
140 
141 /// Performs the rewrites while folding and erasing any dead ops. Returns true
142 /// if the rewrite converges in `maxIterations`.
143 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
144   bool changed = false;
145   unsigned iteration = 0;
146   do {
147     worklist.clear();
148     worklistMap.clear();
149 
150     if (!config.useTopDownTraversal) {
151       // Add operations to the worklist in postorder.
152       for (auto &region : regions)
153         region.walk([this](Operation *op) { addToWorklist(op); });
154     } else {
155       // Add all nested operations to the worklist in preorder.
156       for (auto &region : regions)
157         region.walk<WalkOrder::PreOrder>(
158             [this](Operation *op) { worklist.push_back(op); });
159 
160       // Reverse the list so our pop-back loop processes them in-order.
161       std::reverse(worklist.begin(), worklist.end());
162       // Remember the reverse index.
163       for (size_t i = 0, e = worklist.size(); i != e; ++i)
164         worklistMap[worklist[i]] = i;
165     }
166 
167     // These are scratch vectors used in the folding loop below.
168     SmallVector<Value, 8> originalOperands, resultValues;
169 
170     changed = false;
171     while (!worklist.empty()) {
172       auto *op = popFromWorklist();
173 
174       // Nulls get added to the worklist when operations are removed, ignore
175       // them.
176       if (op == nullptr)
177         continue;
178 
179       // If the operation is trivially dead - remove it.
180       if (isOpTriviallyDead(op)) {
181         notifyOperationRemoved(op);
182         op->erase();
183         changed = true;
184         continue;
185       }
186 
187       // Collects all the operands and result uses of the given `op` into work
188       // list. Also remove `op` and nested ops from worklist.
189       originalOperands.assign(op->operand_begin(), op->operand_end());
190       auto preReplaceAction = [&](Operation *op) {
191         // Add the operands to the worklist for visitation.
192         addToWorklist(originalOperands);
193 
194         // Add all the users of the result to the worklist so we make sure
195         // to revisit them.
196         for (auto result : op->getResults())
197           for (auto *userOp : result.getUsers())
198             addToWorklist(userOp);
199 
200         notifyOperationRemoved(op);
201       };
202 
203       // Add the given operation to the worklist.
204       auto collectOps = [this](Operation *op) { addToWorklist(op); };
205 
206       // Try to fold this op.
207       bool inPlaceUpdate;
208       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
209                                       &inPlaceUpdate)))) {
210         changed = true;
211         if (!inPlaceUpdate)
212           continue;
213       }
214 
215       // Try to match one of the patterns. The rewriter is automatically
216       // notified of any necessary changes, so there is nothing else to do
217       // here.
218       changed |= succeeded(matcher.matchAndRewrite(op, *this));
219     }
220 
221     // After applying patterns, make sure that the CFG of each of the regions
222     // is kept up to date.
223     if (config.enableRegionSimplification)
224       changed |= succeeded(simplifyRegions(*this, regions));
225   } while (changed && ++iteration < config.maxIterations);
226 
227   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
228   return !changed;
229 }
230 
231 /// Rewrite the regions of the specified operation, which must be isolated from
232 /// above, by repeatedly applying the highest benefit patterns in a greedy
233 /// work-list driven manner. Return success if no more patterns can be matched
234 /// in the result operation regions. Note: This does not apply patterns to the
235 /// top-level operation itself.
236 ///
237 LogicalResult
238 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
239                                    const FrozenRewritePatternSet &patterns,
240                                    GreedyRewriteConfig config) {
241   if (regions.empty())
242     return success();
243 
244   // The top-level operation must be known to be isolated from above to
245   // prevent performing canonicalizations on operations defined at or above
246   // the region containing 'op'.
247   auto regionIsIsolated = [](Region &region) {
248     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
249   };
250   (void)regionIsIsolated;
251   assert(llvm::all_of(regions, regionIsIsolated) &&
252          "patterns can only be applied to operations IsolatedFromAbove");
253 
254   // Start the pattern driver.
255   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
256   bool converged = driver.simplify(regions);
257   LLVM_DEBUG(if (!converged) {
258     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
259                  << config.maxIterations << " times\n";
260   });
261   return success(converged);
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // OpPatternRewriteDriver
266 //===----------------------------------------------------------------------===//
267 
268 namespace {
269 /// This is a simple driver for the PatternMatcher to apply patterns and perform
270 /// folding on a single op. It repeatedly applies locally optimal patterns.
271 class OpPatternRewriteDriver : public PatternRewriter {
272 public:
273   explicit OpPatternRewriteDriver(MLIRContext *ctx,
274                                   const FrozenRewritePatternSet &patterns)
275       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
276     // Apply a simple cost model based solely on pattern benefit.
277     matcher.applyDefaultCostModel();
278   }
279 
280   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
281 
282   // These are hooks implemented for PatternRewriter.
283 protected:
284   /// If an operation is about to be removed, mark it so that we can let clients
285   /// know.
286   void notifyOperationRemoved(Operation *op) override {
287     opErasedViaPatternRewrites = true;
288   }
289 
290   // When a root is going to be replaced, its removal will be notified as well.
291   // So there is nothing to do here.
292   void notifyRootReplaced(Operation *op) override {}
293 
294 private:
295   /// The low-level pattern applicator.
296   PatternApplicator matcher;
297 
298   /// Non-pattern based folder for operations.
299   OperationFolder folder;
300 
301   /// Set to true if the operation has been erased via pattern rewrites.
302   bool opErasedViaPatternRewrites = false;
303 };
304 
305 } // anonymous namespace
306 
307 /// Performs the rewrites and folding only on `op`. The simplification
308 /// converges if the op is erased as a result of being folded, replaced, or
309 /// becoming dead, or no more changes happen in an iteration. Returns success if
310 /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
311 /// gets erased.
312 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
313                                                       int maxIterations,
314                                                       bool &erased) {
315   bool changed = false;
316   erased = false;
317   opErasedViaPatternRewrites = false;
318   int iterations = 0;
319   // Iterate until convergence or until maxIterations. Deletion of the op as
320   // a result of being dead or folded is convergence.
321   do {
322     changed = false;
323 
324     // If the operation is trivially dead - remove it.
325     if (isOpTriviallyDead(op)) {
326       op->erase();
327       erased = true;
328       return success();
329     }
330 
331     // Try to fold this op.
332     bool inPlaceUpdate;
333     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
334                                    /*preReplaceAction=*/nullptr,
335                                    &inPlaceUpdate))) {
336       changed = true;
337       if (!inPlaceUpdate) {
338         erased = true;
339         return success();
340       }
341     }
342 
343     // Try to match one of the patterns. The rewriter is automatically
344     // notified of any necessary changes, so there is nothing else to do here.
345     changed |= succeeded(matcher.matchAndRewrite(op, *this));
346     if ((erased = opErasedViaPatternRewrites))
347       return success();
348   } while (changed && ++iterations < maxIterations);
349 
350   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
351   return failure(changed);
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // MultiOpPatternRewriteDriver
356 //===----------------------------------------------------------------------===//
357 
358 namespace {
359 
360 /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
361 /// perform folding for a supplied set of ops. It repeatedly simplifies while
362 /// restricting the rewrites to only the provided set of ops or optionally
363 /// to those directly affected by it (result users or operand providers).
364 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
365 public:
366   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
367                                        const FrozenRewritePatternSet &patterns,
368                                        bool strict)
369       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
370         strictMode(strict) {}
371 
372   bool simplifyLocally(ArrayRef<Operation *> op);
373 
374 private:
375   // Look over the provided operands for any defining operations that should
376   // be re-added to the worklist. This function should be called when an
377   // operation is modified or removed, as it may trigger further
378   // simplifications. If `strict` is set to true, only ops in
379   // `strictModeFilteredOps` are considered.
380   template <typename Operands>
381   void addOperandsToWorklist(Operands &&operands) {
382     for (Value operand : operands) {
383       if (auto *defOp = operand.getDefiningOp()) {
384         if (!strictMode || strictModeFilteredOps.contains(defOp))
385           addToWorklist(defOp);
386       }
387     }
388   }
389 
390   void notifyOperationRemoved(Operation *op) override {
391     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
392     if (strictMode)
393       strictModeFilteredOps.erase(op);
394   }
395 
396   /// If `strictMode` is true, any pre-existing ops outside of
397   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
398   /// If `strictMode` is false, operations that use results of (or supply
399   /// operands to) any rewritten ops stemming from the simplification of the
400   /// provided ops are in turn simplified; any other ops still remain untouched
401   /// (i.e., regardless of `strictMode`).
402   bool strictMode = false;
403 
404   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
405   /// These include the supplied set of ops as well as new ops created while
406   /// rewriting those ops. This set is not maintained when strictMode is off.
407   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
408 };
409 
410 } // end anonymous namespace
411 
412 /// Performs the specified rewrites on `ops` while also trying to fold these ops
413 /// as well as any other ops that were in turn created due to these rewrite
414 /// patterns. Any pre-existing ops outside of `ops` remain completely
415 /// unmodified if `strictMode` is true. If `strictMode` is false, other
416 /// operations that use results of rewritten ops or supply operands to such ops
417 /// are in turn simplified; any other ops still remain unmodified (i.e.,
418 /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
419 /// result of folding, becoming dead, or via pattern rewrites. Returns true if
420 /// at all any changes happened.
421 // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
422 // or GreedyPatternRewriteDriver::simplify, this method just iterates until
423 // the worklist is empty. As our objective is to keep simplification "local",
424 // there is no strong rationale to re-add all operations into the worklist and
425 // rerun until an iteration changes nothing. If more widereaching simplification
426 // is desired, GreedyPatternRewriteDriver should be used.
427 bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
428   if (strictMode) {
429     strictModeFilteredOps.clear();
430     strictModeFilteredOps.insert(ops.begin(), ops.end());
431   }
432 
433   bool changed = false;
434   worklist.clear();
435   worklistMap.clear();
436   for (Operation *op : ops)
437     addToWorklist(op);
438 
439   // These are scratch vectors used in the folding loop below.
440   SmallVector<Value, 8> originalOperands, resultValues;
441   while (!worklist.empty()) {
442     Operation *op = popFromWorklist();
443 
444     // Nulls get added to the worklist when operations are removed, ignore
445     // them.
446     if (op == nullptr)
447       continue;
448 
449     // If the operation is trivially dead - remove it.
450     if (isOpTriviallyDead(op)) {
451       notifyOperationRemoved(op);
452       op->erase();
453       changed = true;
454       continue;
455     }
456 
457     // Collects all the operands and result uses of the given `op` into work
458     // list. Also remove `op` and nested ops from worklist.
459     originalOperands.assign(op->operand_begin(), op->operand_end());
460     auto preReplaceAction = [&](Operation *op) {
461       // Add the operands to the worklist for visitation.
462       addOperandsToWorklist(originalOperands);
463 
464       // Add all the users of the result to the worklist so we make sure
465       // to revisit them.
466       for (Value result : op->getResults())
467         for (Operation *userOp : result.getUsers()) {
468           if (!strictMode || strictModeFilteredOps.contains(userOp))
469             addToWorklist(userOp);
470         }
471       notifyOperationRemoved(op);
472     };
473 
474     // Add the given operation generated by the folder to the worklist.
475     auto processGeneratedConstants = [this](Operation *op) {
476       // Newly created ops are also simplified -- these are also "local".
477       addToWorklist(op);
478       // When strict mode is off, we don't need to maintain
479       // strictModeFilteredOps.
480       if (strictMode)
481         strictModeFilteredOps.insert(op);
482     };
483 
484     // Try to fold this op.
485     bool inPlaceUpdate;
486     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
487                                    preReplaceAction, &inPlaceUpdate))) {
488       changed = true;
489       if (!inPlaceUpdate) {
490         // Op has been erased.
491         continue;
492       }
493     }
494 
495     // Try to match one of the patterns. The rewriter is automatically
496     // notified of any necessary changes, so there is nothing else to do
497     // here.
498     changed |= succeeded(matcher.matchAndRewrite(op, *this));
499   }
500 
501   return changed;
502 }
503 
504 /// Rewrites only `op` using the supplied canonicalization patterns and
505 /// folding. `erased` is set to true if the op is erased as a result of being
506 /// folded, replaced, or dead.
507 LogicalResult mlir::applyOpPatternsAndFold(
508     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
509   // Start the pattern driver.
510   GreedyRewriteConfig config;
511   OpPatternRewriteDriver driver(op->getContext(), patterns);
512   bool opErased;
513   LogicalResult converged =
514       driver.simplifyLocally(op, config.maxIterations, opErased);
515   if (erased)
516     *erased = opErased;
517   LLVM_DEBUG(if (failed(converged)) {
518     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
519                  << config.maxIterations << " times";
520   });
521   return converged;
522 }
523 
524 bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
525                                   const FrozenRewritePatternSet &patterns,
526                                   bool strict) {
527   if (ops.empty())
528     return false;
529 
530   // Start the pattern driver.
531   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
532                                      strict);
533   return driver.simplifyLocally(ops);
534 }
535