1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements mlir::applyPatternsAndFoldGreedily and
// mlir::applyOpPatternsAndFold.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/RegionUtils.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "pattern-matcher"
26 
27 //===----------------------------------------------------------------------===//
28 // GreedyPatternRewriteDriver
29 //===----------------------------------------------------------------------===//
30 
namespace {
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns in a roughly "bottom up" way.
///
/// The driver is also the PatternRewriter handed to the patterns. Its notify*
/// hooks keep the worklist in sync with the IR as operations are inserted,
/// replaced and removed.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
  /// Create a driver applying `patterns` according to `config`.
  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                      const FrozenRewritePatternSet &patterns,
                                      const GreedyRewriteConfig &config);

  /// Simplify the operations within the given regions. Returns true if the
  /// rewrite converged, i.e. the last iteration made no changes.
  bool simplify(MutableArrayRef<Region> regions);

  /// Add the given operation to the worklist. Does nothing if the operation
  /// is already queued.
  void addToWorklist(Operation *op);

  /// Pop the next operation from the back of the worklist. The result may be
  /// null: slots of operations removed while queued are nulled out rather
  /// than erased.
  Operation *popFromWorklist();

  /// If the specified operation is in the worklist, remove it (its slot is
  /// replaced with null).
  void removeFromWorklist(Operation *op);

protected:
  // Implement the hook for inserting operations, and make sure that newly
  // inserted ops are added to the worklist for processing.
  void notifyOperationInserted(Operation *op) override;

  // Look over the provided operands for any defining operations that should
  // be re-added to the worklist. This function should be called when an
  // operation is modified or removed, as it may trigger further
  // simplifications. Only producers of values left with zero or one use are
  // re-added.
  template <typename Operands>
  void addToWorklist(Operands &&operands);

  // If an operation is about to be removed, make sure neither it nor any op
  // nested inside it remains in our worklist (or in the folder's state),
  // because we'd get dangling references to it.
  void notifyOperationRemoved(Operation *op) override;

  // When the root of a pattern is about to be replaced, it can trigger
  // simplifications to its users - make sure to add them to the worklist
  // before the root is changed.
  void notifyRootReplaced(Operation *op) override;

  /// The low-level pattern applicator.
  PatternApplicator matcher;

  /// The worklist for this transformation keeps track of the operations that
  /// need to be revisited, plus their index in the worklist.  This allows us to
  /// efficiently remove operations from the worklist when they are erased, even
  /// if they aren't the root of a pattern.
  /// `worklistMap` maps every queued (non-null) operation to its index in
  /// `worklist`; removed entries leave a null slot behind in `worklist`.
  std::vector<Operation *> worklist;
  DenseMap<Operation *, unsigned> worklistMap;

  /// Non-pattern based folder for operations.
  OperationFolder folder;

private:
  /// Configuration information for how to simplify.
  GreedyRewriteConfig config;
};
} // end anonymous namespace
91 
92 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
93     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
94     const GreedyRewriteConfig &config)
95     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
96   worklist.reserve(64);
97 
98   // Apply a simple cost model based solely on pattern benefit.
99   matcher.applyDefaultCostModel();
100 }
101 
102 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
103   bool changed = false;
104   unsigned iteration = 0;
105   do {
106     worklist.clear();
107     worklistMap.clear();
108 
109     if (!config.useTopDownTraversal) {
110       // Add operations to the worklist in postorder.
111       for (auto &region : regions)
112         region.walk([this](Operation *op) { addToWorklist(op); });
113     } else {
114       // Add all nested operations to the worklist in preorder.
115       for (auto &region : regions)
116         region.walk<WalkOrder::PreOrder>(
117             [this](Operation *op) { worklist.push_back(op); });
118 
119       // Reverse the list so our pop-back loop processes them in-order.
120       std::reverse(worklist.begin(), worklist.end());
121       // Remember the reverse index.
122       for (size_t i = 0, e = worklist.size(); i != e; ++i)
123         worklistMap[worklist[i]] = i;
124     }
125 
126     // These are scratch vectors used in the folding loop below.
127     SmallVector<Value, 8> originalOperands, resultValues;
128 
129     changed = false;
130     while (!worklist.empty()) {
131       auto *op = popFromWorklist();
132 
133       // Nulls get added to the worklist when operations are removed, ignore
134       // them.
135       if (op == nullptr)
136         continue;
137 
138       // If the operation is trivially dead - remove it.
139       if (isOpTriviallyDead(op)) {
140         notifyOperationRemoved(op);
141         op->erase();
142         changed = true;
143         continue;
144       }
145 
146       // Collects all the operands and result uses of the given `op` into work
147       // list. Also remove `op` and nested ops from worklist.
148       originalOperands.assign(op->operand_begin(), op->operand_end());
149       auto preReplaceAction = [&](Operation *op) {
150         // Add the operands to the worklist for visitation.
151         addToWorklist(originalOperands);
152 
153         // Add all the users of the result to the worklist so we make sure
154         // to revisit them.
155         for (auto result : op->getResults())
156           for (auto *userOp : result.getUsers())
157             addToWorklist(userOp);
158 
159         notifyOperationRemoved(op);
160       };
161 
162       // Add the given operation to the worklist.
163       auto collectOps = [this](Operation *op) { addToWorklist(op); };
164 
165       // Try to fold this op.
166       bool inPlaceUpdate;
167       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
168                                       &inPlaceUpdate)))) {
169         changed = true;
170         if (!inPlaceUpdate)
171           continue;
172       }
173 
174       // Try to match one of the patterns. The rewriter is automatically
175       // notified of any necessary changes, so there is nothing else to do
176       // here.
177       changed |= succeeded(matcher.matchAndRewrite(op, *this));
178     }
179 
180     // After applying patterns, make sure that the CFG of each of the regions
181     // is kept up to date.
182     if (config.enableRegionSimplification)
183       changed |= succeeded(simplifyRegions(*this, regions));
184   } while (changed &&
185            (++iteration < config.maxIterations ||
186             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
187 
188   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
189   return !changed;
190 }
191 
192 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
193   // Check to see if the worklist already contains this op.
194   if (worklistMap.count(op))
195     return;
196 
197   worklistMap[op] = worklist.size();
198   worklist.push_back(op);
199 }
200 
201 Operation *GreedyPatternRewriteDriver::popFromWorklist() {
202   auto *op = worklist.back();
203   worklist.pop_back();
204 
205   // This operation is no longer in the worklist, keep worklistMap up to date.
206   if (op)
207     worklistMap.erase(op);
208   return op;
209 }
210 
211 void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
212   auto it = worklistMap.find(op);
213   if (it != worklistMap.end()) {
214     assert(worklist[it->second] == op && "malformed worklist data structure");
215     worklist[it->second] = nullptr;
216     worklistMap.erase(it);
217   }
218 }
219 
220 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
221   addToWorklist(op);
222 }
223 
224 template <typename Operands>
225 void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
226   for (Value operand : operands) {
227     // If the use count of this operand is now < 2, we re-add the defining
228     // operation to the worklist.
229     // TODO: This is based on the fact that zero use operations
230     // may be deleted, and that single use values often have more
231     // canonicalization opportunities.
232     if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
233       continue;
234     if (auto *defOp = operand.getDefiningOp())
235       addToWorklist(defOp);
236   }
237 }
238 
239 void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
240   addToWorklist(op->getOperands());
241   op->walk([this](Operation *operation) {
242     removeFromWorklist(operation);
243     folder.notifyRemoval(operation);
244   });
245 }
246 
247 void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
248   for (auto result : op->getResults())
249     for (auto *user : result.getUsers())
250       addToWorklist(user);
251 }
252 
253 /// Rewrite the regions of the specified operation, which must be isolated from
254 /// above, by repeatedly applying the highest benefit patterns in a greedy
255 /// work-list driven manner. Return success if no more patterns can be matched
256 /// in the result operation regions. Note: This does not apply patterns to the
257 /// top-level operation itself.
258 ///
259 LogicalResult
260 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
261                                    const FrozenRewritePatternSet &patterns,
262                                    GreedyRewriteConfig config) {
263   if (regions.empty())
264     return success();
265 
266   // The top-level operation must be known to be isolated from above to
267   // prevent performing canonicalizations on operations defined at or above
268   // the region containing 'op'.
269   auto regionIsIsolated = [](Region &region) {
270     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
271   };
272   (void)regionIsIsolated;
273   assert(llvm::all_of(regions, regionIsIsolated) &&
274          "patterns can only be applied to operations IsolatedFromAbove");
275 
276   // Start the pattern driver.
277   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
278   bool converged = driver.simplify(regions);
279   LLVM_DEBUG(if (!converged) {
280     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
281                  << config.maxIterations << " times\n";
282   });
283   return success(converged);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // OpPatternRewriteDriver
288 //===----------------------------------------------------------------------===//
289 
290 namespace {
291 /// This is a simple driver for the PatternMatcher to apply patterns and perform
292 /// folding on a single op. It repeatedly applies locally optimal patterns.
293 class OpPatternRewriteDriver : public PatternRewriter {
294 public:
295   explicit OpPatternRewriteDriver(MLIRContext *ctx,
296                                   const FrozenRewritePatternSet &patterns)
297       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
298     // Apply a simple cost model based solely on pattern benefit.
299     matcher.applyDefaultCostModel();
300   }
301 
302   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
303 
304   // These are hooks implemented for PatternRewriter.
305 protected:
306   /// If an operation is about to be removed, mark it so that we can let clients
307   /// know.
308   void notifyOperationRemoved(Operation *op) override {
309     opErasedViaPatternRewrites = true;
310   }
311 
312   // When a root is going to be replaced, its removal will be notified as well.
313   // So there is nothing to do here.
314   void notifyRootReplaced(Operation *op) override {}
315 
316 private:
317   /// The low-level pattern applicator.
318   PatternApplicator matcher;
319 
320   /// Non-pattern based folder for operations.
321   OperationFolder folder;
322 
323   /// Set to true if the operation has been erased via pattern rewrites.
324   bool opErasedViaPatternRewrites = false;
325 };
326 
327 } // anonymous namespace
328 
329 /// Performs the rewrites and folding only on `op`. The simplification
330 /// converges if the op is erased as a result of being folded, replaced, or
331 /// becoming dead, or no more changes happen in an iteration. Returns success if
332 /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
333 /// gets erased.
334 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
335                                                       int maxIterations,
336                                                       bool &erased) {
337   bool changed = false;
338   erased = false;
339   opErasedViaPatternRewrites = false;
340   int iterations = 0;
341   // Iterate until convergence or until maxIterations. Deletion of the op as
342   // a result of being dead or folded is convergence.
343   do {
344     changed = false;
345 
346     // If the operation is trivially dead - remove it.
347     if (isOpTriviallyDead(op)) {
348       op->erase();
349       erased = true;
350       return success();
351     }
352 
353     // Try to fold this op.
354     bool inPlaceUpdate;
355     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
356                                    /*preReplaceAction=*/nullptr,
357                                    &inPlaceUpdate))) {
358       changed = true;
359       if (!inPlaceUpdate) {
360         erased = true;
361         return success();
362       }
363     }
364 
365     // Try to match one of the patterns. The rewriter is automatically
366     // notified of any necessary changes, so there is nothing else to do here.
367     changed |= succeeded(matcher.matchAndRewrite(op, *this));
368     if ((erased = opErasedViaPatternRewrites))
369       return success();
370   } while (changed &&
371            (++iterations < maxIterations ||
372             maxIterations == GreedyRewriteConfig::kNoIterationLimit));
373 
374   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
375   return failure(changed);
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // MultiOpPatternRewriteDriver
380 //===----------------------------------------------------------------------===//
381 
382 namespace {
383 
384 /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
385 /// perform folding for a supplied set of ops. It repeatedly simplifies while
386 /// restricting the rewrites to only the provided set of ops or optionally
387 /// to those directly affected by it (result users or operand providers).
388 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
389 public:
390   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
391                                        const FrozenRewritePatternSet &patterns,
392                                        bool strict)
393       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
394         strictMode(strict) {}
395 
396   bool simplifyLocally(ArrayRef<Operation *> op);
397 
398 private:
399   // Look over the provided operands for any defining operations that should
400   // be re-added to the worklist. This function should be called when an
401   // operation is modified or removed, as it may trigger further
402   // simplifications. If `strict` is set to true, only ops in
403   // `strictModeFilteredOps` are considered.
404   template <typename Operands>
405   void addOperandsToWorklist(Operands &&operands) {
406     for (Value operand : operands) {
407       if (auto *defOp = operand.getDefiningOp()) {
408         if (!strictMode || strictModeFilteredOps.contains(defOp))
409           addToWorklist(defOp);
410       }
411     }
412   }
413 
414   void notifyOperationRemoved(Operation *op) override {
415     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
416     if (strictMode)
417       strictModeFilteredOps.erase(op);
418   }
419 
420   /// If `strictMode` is true, any pre-existing ops outside of
421   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
422   /// If `strictMode` is false, operations that use results of (or supply
423   /// operands to) any rewritten ops stemming from the simplification of the
424   /// provided ops are in turn simplified; any other ops still remain untouched
425   /// (i.e., regardless of `strictMode`).
426   bool strictMode = false;
427 
428   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
429   /// These include the supplied set of ops as well as new ops created while
430   /// rewriting those ops. This set is not maintained when strictMode is off.
431   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
432 };
433 
434 } // end anonymous namespace
435 
436 /// Performs the specified rewrites on `ops` while also trying to fold these ops
437 /// as well as any other ops that were in turn created due to these rewrite
438 /// patterns. Any pre-existing ops outside of `ops` remain completely
439 /// unmodified if `strictMode` is true. If `strictMode` is false, other
440 /// operations that use results of rewritten ops or supply operands to such ops
441 /// are in turn simplified; any other ops still remain unmodified (i.e.,
442 /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
443 /// result of folding, becoming dead, or via pattern rewrites. Returns true if
444 /// at all any changes happened.
445 // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
446 // or GreedyPatternRewriteDriver::simplify, this method just iterates until
447 // the worklist is empty. As our objective is to keep simplification "local",
448 // there is no strong rationale to re-add all operations into the worklist and
449 // rerun until an iteration changes nothing. If more widereaching simplification
450 // is desired, GreedyPatternRewriteDriver should be used.
451 bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
452   if (strictMode) {
453     strictModeFilteredOps.clear();
454     strictModeFilteredOps.insert(ops.begin(), ops.end());
455   }
456 
457   bool changed = false;
458   worklist.clear();
459   worklistMap.clear();
460   for (Operation *op : ops)
461     addToWorklist(op);
462 
463   // These are scratch vectors used in the folding loop below.
464   SmallVector<Value, 8> originalOperands, resultValues;
465   while (!worklist.empty()) {
466     Operation *op = popFromWorklist();
467 
468     // Nulls get added to the worklist when operations are removed, ignore
469     // them.
470     if (op == nullptr)
471       continue;
472 
473     // If the operation is trivially dead - remove it.
474     if (isOpTriviallyDead(op)) {
475       notifyOperationRemoved(op);
476       op->erase();
477       changed = true;
478       continue;
479     }
480 
481     // Collects all the operands and result uses of the given `op` into work
482     // list. Also remove `op` and nested ops from worklist.
483     originalOperands.assign(op->operand_begin(), op->operand_end());
484     auto preReplaceAction = [&](Operation *op) {
485       // Add the operands to the worklist for visitation.
486       addOperandsToWorklist(originalOperands);
487 
488       // Add all the users of the result to the worklist so we make sure
489       // to revisit them.
490       for (Value result : op->getResults())
491         for (Operation *userOp : result.getUsers()) {
492           if (!strictMode || strictModeFilteredOps.contains(userOp))
493             addToWorklist(userOp);
494         }
495       notifyOperationRemoved(op);
496     };
497 
498     // Add the given operation generated by the folder to the worklist.
499     auto processGeneratedConstants = [this](Operation *op) {
500       // Newly created ops are also simplified -- these are also "local".
501       addToWorklist(op);
502       // When strict mode is off, we don't need to maintain
503       // strictModeFilteredOps.
504       if (strictMode)
505         strictModeFilteredOps.insert(op);
506     };
507 
508     // Try to fold this op.
509     bool inPlaceUpdate;
510     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
511                                    preReplaceAction, &inPlaceUpdate))) {
512       changed = true;
513       if (!inPlaceUpdate) {
514         // Op has been erased.
515         continue;
516       }
517     }
518 
519     // Try to match one of the patterns. The rewriter is automatically
520     // notified of any necessary changes, so there is nothing else to do
521     // here.
522     changed |= succeeded(matcher.matchAndRewrite(op, *this));
523   }
524 
525   return changed;
526 }
527 
528 /// Rewrites only `op` using the supplied canonicalization patterns and
529 /// folding. `erased` is set to true if the op is erased as a result of being
530 /// folded, replaced, or dead.
531 LogicalResult mlir::applyOpPatternsAndFold(
532     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
533   // Start the pattern driver.
534   GreedyRewriteConfig config;
535   OpPatternRewriteDriver driver(op->getContext(), patterns);
536   bool opErased;
537   LogicalResult converged =
538       driver.simplifyLocally(op, config.maxIterations, opErased);
539   if (erased)
540     *erased = opErased;
541   LLVM_DEBUG(if (failed(converged)) {
542     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
543                  << config.maxIterations << " times";
544   });
545   return converged;
546 }
547 
548 bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
549                                   const FrozenRewritePatternSet &patterns,
550                                   bool strict) {
551   if (ops.empty())
552     return false;
553 
554   // Start the pattern driver.
555   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
556                                      strict);
557   return driver.simplifyLocally(ops);
558 }
559