164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
664d52014SChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
864d52014SChris Lattner //
9a60fdd2bSLorenzo Chelini // This file implements mlir::applyPatternsAndFoldGreedily.
1064d52014SChris Lattner //
1164d52014SChris Lattner //===----------------------------------------------------------------------===//
1264d52014SChris Lattner 
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
15b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
161982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
17fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
1864d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
195c757087SFeng Liu #include "llvm/Support/CommandLine.h"
205c757087SFeng Liu #include "llvm/Support/Debug.h"
21*5652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h"
225c757087SFeng Liu #include "llvm/Support/raw_ostream.h"
234e40c832SLei Zhang 
2464d52014SChris Lattner using namespace mlir;
2564d52014SChris Lattner 
26*5652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter"
275c757087SFeng Liu 
2804b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
2904b5274eSUday Bondhugula // GreedyPatternRewriteDriver
3004b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
3104b5274eSUday Bondhugula 
3264d52014SChris Lattner namespace {
3364d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3464d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
354bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3664d52014SChris Lattner public:
372566a72aSRiver Riddle   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
38648f34a2SChris Lattner                                       const FrozenRewritePatternSet &patterns,
39b7144ab7SRiver Riddle                                       const GreedyRewriteConfig &config);
403e98fbf4SRiver Riddle 
41b7144ab7SRiver Riddle   /// Simplify the operations within the given regions.
4264716b2cSChris Lattner   bool simplify(MutableArrayRef<Region> regions);
4364d52014SChris Lattner 
44b7144ab7SRiver Riddle   /// Add the given operation to the worklist.
45b7144ab7SRiver Riddle   void addToWorklist(Operation *op);
465c4f1fddSRiver Riddle 
47b7144ab7SRiver Riddle   /// Pop the next operation from the worklist.
48b7144ab7SRiver Riddle   Operation *popFromWorklist();
4964d52014SChris Lattner 
50b7144ab7SRiver Riddle   /// If the specified operation is in the worklist, remove it.
51b7144ab7SRiver Riddle   void removeFromWorklist(Operation *op);
5264d52014SChris Lattner 
534bd9f936SChris Lattner protected:
54851a8516SRiver Riddle   // Implement the hook for inserting operations, and make sure that newly
55851a8516SRiver Riddle   // inserted ops are added to the worklist for processing.
56b7144ab7SRiver Riddle   void notifyOperationInserted(Operation *op) override;
5764d52014SChris Lattner 
587932d21fSUday Bondhugula   // Look over the provided operands for any defining operations that should
597932d21fSUday Bondhugula   // be re-added to the worklist. This function should be called when an
607932d21fSUday Bondhugula   // operation is modified or removed, as it may trigger further
617932d21fSUday Bondhugula   // simplifications.
627932d21fSUday Bondhugula   template <typename Operands>
63b7144ab7SRiver Riddle   void addToWorklist(Operands &&operands);
647932d21fSUday Bondhugula 
6564d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
6664d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
67b7144ab7SRiver Riddle   void notifyOperationRemoved(Operation *op) override;
6864d52014SChris Lattner 
69085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
70085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
71085b687fSChris Lattner   // before the root is changed.
72b7144ab7SRiver Riddle   void notifyRootReplaced(Operation *op) override;
73085b687fSChris Lattner 
74*5652ecc3SRiver Riddle   /// PatternRewriter hook for erasing a dead operation.
75*5652ecc3SRiver Riddle   void eraseOp(Operation *op) override;
76*5652ecc3SRiver Riddle 
77*5652ecc3SRiver Riddle   /// PatternRewriter hook for notifying match failure reasons.
78*5652ecc3SRiver Riddle   LogicalResult
79*5652ecc3SRiver Riddle   notifyMatchFailure(Operation *op,
80*5652ecc3SRiver Riddle                      function_ref<void(Diagnostic &)> reasonCallback) override;
81*5652ecc3SRiver Riddle 
823e98fbf4SRiver Riddle   /// The low-level pattern applicator.
833e98fbf4SRiver Riddle   PatternApplicator matcher;
844bd9f936SChris Lattner 
854bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
864bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
87e7a2ef21SRiver Riddle   /// efficiently remove operations from the worklist when they are erased, even
88e7a2ef21SRiver Riddle   /// if they aren't the root of a pattern.
8999b87c97SRiver Riddle   std::vector<Operation *> worklist;
9099b87c97SRiver Riddle   DenseMap<Operation *, unsigned> worklistMap;
9160a29837SRiver Riddle 
9260a29837SRiver Riddle   /// Non-pattern based folder for operations.
9360a29837SRiver Riddle   OperationFolder folder;
94648f34a2SChris Lattner 
957932d21fSUday Bondhugula private:
9664716b2cSChris Lattner   /// Configuration information for how to simplify.
9764716b2cSChris Lattner   GreedyRewriteConfig config;
98*5652ecc3SRiver Riddle 
99*5652ecc3SRiver Riddle #ifndef NDEBUG
100*5652ecc3SRiver Riddle   /// A logger used to emit information during the application process.
101*5652ecc3SRiver Riddle   llvm::ScopedPrinter logger{llvm::dbgs()};
102*5652ecc3SRiver Riddle #endif
10364d52014SChris Lattner };
10491f07810SMehdi Amini } // end anonymous namespace
10564d52014SChris Lattner 
106b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
107b7144ab7SRiver Riddle     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
108b7144ab7SRiver Riddle     const GreedyRewriteConfig &config)
109b7144ab7SRiver Riddle     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
110b7144ab7SRiver Riddle   worklist.reserve(64);
111b7144ab7SRiver Riddle 
112b7144ab7SRiver Riddle   // Apply a simple cost model based solely on pattern benefit.
113b7144ab7SRiver Riddle   matcher.applyDefaultCostModel();
114b7144ab7SRiver Riddle }
115b7144ab7SRiver Riddle 
11664716b2cSChris Lattner bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
117*5652ecc3SRiver Riddle #ifndef NDEBUG
118*5652ecc3SRiver Riddle   const char *logLineComment =
119*5652ecc3SRiver Riddle       "//===-------------------------------------------===//\n";
120*5652ecc3SRiver Riddle 
121*5652ecc3SRiver Riddle   /// A utility function to log a process result for the given reason.
122*5652ecc3SRiver Riddle   auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
123*5652ecc3SRiver Riddle     logger.unindent();
124*5652ecc3SRiver Riddle     logger.startLine() << "} -> " << result;
125*5652ecc3SRiver Riddle     if (!msg.isTriviallyEmpty())
126*5652ecc3SRiver Riddle       logger.getOStream() << " : " << msg;
127*5652ecc3SRiver Riddle     logger.getOStream() << "\n";
128*5652ecc3SRiver Riddle   };
129*5652ecc3SRiver Riddle   auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
130*5652ecc3SRiver Riddle     logResult(result, msg);
131*5652ecc3SRiver Riddle     logger.startLine() << logLineComment;
132*5652ecc3SRiver Riddle   };
133*5652ecc3SRiver Riddle #endif
134*5652ecc3SRiver Riddle 
1355c757087SFeng Liu   bool changed = false;
13664716b2cSChris Lattner   unsigned iteration = 0;
1375c757087SFeng Liu   do {
138648f34a2SChris Lattner     worklist.clear();
139648f34a2SChris Lattner     worklistMap.clear();
140648f34a2SChris Lattner 
14164716b2cSChris Lattner     if (!config.useTopDownTraversal) {
14264716b2cSChris Lattner       // Add operations to the worklist in postorder.
14364716b2cSChris Lattner       for (auto &region : regions)
14464716b2cSChris Lattner         region.walk([this](Operation *op) { addToWorklist(op); });
14564716b2cSChris Lattner     } else {
146648f34a2SChris Lattner       // Add all nested operations to the worklist in preorder.
1476b1cc3c6SRiver Riddle       for (auto &region : regions)
148648f34a2SChris Lattner         region.walk<WalkOrder::PreOrder>(
149648f34a2SChris Lattner             [this](Operation *op) { worklist.push_back(op); });
150648f34a2SChris Lattner 
151648f34a2SChris Lattner       // Reverse the list so our pop-back loop processes them in-order.
152648f34a2SChris Lattner       std::reverse(worklist.begin(), worklist.end());
153648f34a2SChris Lattner       // Remember the reverse index.
154648f34a2SChris Lattner       for (size_t i = 0, e = worklist.size(); i != e; ++i)
155648f34a2SChris Lattner         worklistMap[worklist[i]] = i;
156648f34a2SChris Lattner     }
1574e40c832SLei Zhang 
1584e40c832SLei Zhang     // These are scratch vectors used in the folding loop below.
159e62a6956SRiver Riddle     SmallVector<Value, 8> originalOperands, resultValues;
16064d52014SChris Lattner 
1615c757087SFeng Liu     changed = false;
16264d52014SChris Lattner     while (!worklist.empty()) {
16364d52014SChris Lattner       auto *op = popFromWorklist();
16464d52014SChris Lattner 
1655c757087SFeng Liu       // Nulls get added to the worklist when operations are removed, ignore
1665c757087SFeng Liu       // them.
16764d52014SChris Lattner       if (op == nullptr)
16864d52014SChris Lattner         continue;
16964d52014SChris Lattner 
170*5652ecc3SRiver Riddle       LLVM_DEBUG({
171*5652ecc3SRiver Riddle         logger.getOStream() << "\n";
172*5652ecc3SRiver Riddle         logger.startLine() << logLineComment;
173*5652ecc3SRiver Riddle         logger.startLine() << "Processing operation : '" << op->getName()
174*5652ecc3SRiver Riddle                            << "'(" << op << ") {\n";
175*5652ecc3SRiver Riddle         logger.indent();
176*5652ecc3SRiver Riddle 
177*5652ecc3SRiver Riddle         // If the operation has no regions, just print it here.
178*5652ecc3SRiver Riddle         if (op->getNumRegions() == 0) {
179*5652ecc3SRiver Riddle           op->print(
180*5652ecc3SRiver Riddle               logger.startLine(),
181*5652ecc3SRiver Riddle               OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
182*5652ecc3SRiver Riddle           logger.getOStream() << "\n\n";
183*5652ecc3SRiver Riddle         }
184*5652ecc3SRiver Riddle       });
185*5652ecc3SRiver Riddle 
1860ddba0bdSRiver Riddle       // If the operation is trivially dead - remove it.
1870ddba0bdSRiver Riddle       if (isOpTriviallyDead(op)) {
1886a501e3dSAndy Ly         notifyOperationRemoved(op);
18964d52014SChris Lattner         op->erase();
190f875e55bSUday Bondhugula         changed = true;
191*5652ecc3SRiver Riddle 
192*5652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
19364d52014SChris Lattner         continue;
19464d52014SChris Lattner       }
19564d52014SChris Lattner 
1964e40c832SLei Zhang       // Collects all the operands and result uses of the given `op` into work
1976a501e3dSAndy Ly       // list. Also remove `op` and nested ops from worklist.
1981982afb1SRiver Riddle       originalOperands.assign(op->operand_begin(), op->operand_end());
1996a501e3dSAndy Ly       auto preReplaceAction = [&](Operation *op) {
200a8866258SRiver Riddle         // Add the operands to the worklist for visitation.
2011982afb1SRiver Riddle         addToWorklist(originalOperands);
2021982afb1SRiver Riddle 
2034e40c832SLei Zhang         // Add all the users of the result to the worklist so we make sure
2044e40c832SLei Zhang         // to revisit them.
20535807bc4SRiver Riddle         for (auto result : op->getResults())
206cc673894SUday Bondhugula           for (auto *userOp : result.getUsers())
207cc673894SUday Bondhugula             addToWorklist(userOp);
2086a501e3dSAndy Ly 
2096a501e3dSAndy Ly         notifyOperationRemoved(op);
2104e40c832SLei Zhang       };
21164d52014SChris Lattner 
212648f34a2SChris Lattner       // Add the given operation to the worklist.
213648f34a2SChris Lattner       auto collectOps = [this](Operation *op) { addToWorklist(op); };
214648f34a2SChris Lattner 
2151982afb1SRiver Riddle       // Try to fold this op.
216cbcb12fdSUday Bondhugula       bool inPlaceUpdate;
217cbcb12fdSUday Bondhugula       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
218cbcb12fdSUday Bondhugula                                       &inPlaceUpdate)))) {
219*5652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
220*5652ecc3SRiver Riddle 
221f875e55bSUday Bondhugula         changed = true;
222cbcb12fdSUday Bondhugula         if (!inPlaceUpdate)
223934b6d12SChris Lattner           continue;
22464d52014SChris Lattner       }
22564d52014SChris Lattner 
22632052c84SRiver Riddle       // Try to match one of the patterns. The rewriter is automatically
227648f34a2SChris Lattner       // notified of any necessary changes, so there is nothing else to do
228648f34a2SChris Lattner       // here.
229*5652ecc3SRiver Riddle #ifndef NDEBUG
230*5652ecc3SRiver Riddle       auto canApply = [&](const Pattern &pattern) {
231*5652ecc3SRiver Riddle         LLVM_DEBUG({
232*5652ecc3SRiver Riddle           logger.getOStream() << "\n";
233*5652ecc3SRiver Riddle           logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
234*5652ecc3SRiver Riddle                              << op->getName() << " -> (";
235*5652ecc3SRiver Riddle           llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
236*5652ecc3SRiver Riddle           logger.getOStream() << ")' {\n";
237*5652ecc3SRiver Riddle           logger.indent();
238*5652ecc3SRiver Riddle         });
239*5652ecc3SRiver Riddle         return true;
240*5652ecc3SRiver Riddle       };
241*5652ecc3SRiver Riddle       auto onFailure = [&](const Pattern &pattern) {
242*5652ecc3SRiver Riddle         LLVM_DEBUG(logResult("failure", "pattern failed to match"));
243*5652ecc3SRiver Riddle       };
244*5652ecc3SRiver Riddle       auto onSuccess = [&](const Pattern &pattern) {
245*5652ecc3SRiver Riddle         LLVM_DEBUG(logResult("success", "pattern applied successfully"));
246*5652ecc3SRiver Riddle         return success();
247*5652ecc3SRiver Riddle       };
248*5652ecc3SRiver Riddle 
249*5652ecc3SRiver Riddle       LogicalResult matchResult =
250*5652ecc3SRiver Riddle           matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
251*5652ecc3SRiver Riddle       if (succeeded(matchResult))
252*5652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
253*5652ecc3SRiver Riddle       else
254*5652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
255*5652ecc3SRiver Riddle #else
256*5652ecc3SRiver Riddle       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
257*5652ecc3SRiver Riddle #endif
258*5652ecc3SRiver Riddle 
259*5652ecc3SRiver Riddle 
260*5652ecc3SRiver Riddle #ifndef NDEBUG
261*5652ecc3SRiver Riddle #endif
262*5652ecc3SRiver Riddle 
263*5652ecc3SRiver Riddle       changed |= succeeded(matchResult);
26464d52014SChris Lattner     }
265a32f0dcbSRiver Riddle 
266648f34a2SChris Lattner     // After applying patterns, make sure that the CFG of each of the regions
267648f34a2SChris Lattner     // is kept up to date.
26864716b2cSChris Lattner     if (config.enableRegionSimplification)
269d75a611aSRiver Riddle       changed |= succeeded(simplifyRegions(*this, regions));
270519663beSFrederik Gossen   } while (changed &&
271519663beSFrederik Gossen            (++iteration < config.maxIterations ||
272519663beSFrederik Gossen             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
273648f34a2SChris Lattner 
2745c757087SFeng Liu   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
2755c757087SFeng Liu   return !changed;
27664d52014SChris Lattner }
27764d52014SChris Lattner 
278b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
279b7144ab7SRiver Riddle   // Check to see if the worklist already contains this op.
280b7144ab7SRiver Riddle   if (worklistMap.count(op))
281b7144ab7SRiver Riddle     return;
282b7144ab7SRiver Riddle 
283b7144ab7SRiver Riddle   worklistMap[op] = worklist.size();
284b7144ab7SRiver Riddle   worklist.push_back(op);
285b7144ab7SRiver Riddle }
286b7144ab7SRiver Riddle 
287b7144ab7SRiver Riddle Operation *GreedyPatternRewriteDriver::popFromWorklist() {
288b7144ab7SRiver Riddle   auto *op = worklist.back();
289b7144ab7SRiver Riddle   worklist.pop_back();
290b7144ab7SRiver Riddle 
291b7144ab7SRiver Riddle   // This operation is no longer in the worklist, keep worklistMap up to date.
292b7144ab7SRiver Riddle   if (op)
293b7144ab7SRiver Riddle     worklistMap.erase(op);
294b7144ab7SRiver Riddle   return op;
295b7144ab7SRiver Riddle }
296b7144ab7SRiver Riddle 
297b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
298b7144ab7SRiver Riddle   auto it = worklistMap.find(op);
299b7144ab7SRiver Riddle   if (it != worklistMap.end()) {
300b7144ab7SRiver Riddle     assert(worklist[it->second] == op && "malformed worklist data structure");
301b7144ab7SRiver Riddle     worklist[it->second] = nullptr;
302b7144ab7SRiver Riddle     worklistMap.erase(it);
303b7144ab7SRiver Riddle   }
304b7144ab7SRiver Riddle }
305b7144ab7SRiver Riddle 
306b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
307*5652ecc3SRiver Riddle   LLVM_DEBUG({
308*5652ecc3SRiver Riddle     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
309*5652ecc3SRiver Riddle                        << ")\n";
310*5652ecc3SRiver Riddle   });
311b7144ab7SRiver Riddle   addToWorklist(op);
312b7144ab7SRiver Riddle }
313b7144ab7SRiver Riddle 
314b7144ab7SRiver Riddle template <typename Operands>
315b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
316b7144ab7SRiver Riddle   for (Value operand : operands) {
317b7144ab7SRiver Riddle     // If the use count of this operand is now < 2, we re-add the defining
318b7144ab7SRiver Riddle     // operation to the worklist.
319b7144ab7SRiver Riddle     // TODO: This is based on the fact that zero use operations
320b7144ab7SRiver Riddle     // may be deleted, and that single use values often have more
321b7144ab7SRiver Riddle     // canonicalization opportunities.
322b7144ab7SRiver Riddle     if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
323b7144ab7SRiver Riddle       continue;
324b7144ab7SRiver Riddle     if (auto *defOp = operand.getDefiningOp())
325b7144ab7SRiver Riddle       addToWorklist(defOp);
326b7144ab7SRiver Riddle   }
327b7144ab7SRiver Riddle }
328b7144ab7SRiver Riddle 
329b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
330b7144ab7SRiver Riddle   addToWorklist(op->getOperands());
331b7144ab7SRiver Riddle   op->walk([this](Operation *operation) {
332b7144ab7SRiver Riddle     removeFromWorklist(operation);
333b7144ab7SRiver Riddle     folder.notifyRemoval(operation);
334b7144ab7SRiver Riddle   });
335b7144ab7SRiver Riddle }
336b7144ab7SRiver Riddle 
337b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
338*5652ecc3SRiver Riddle   LLVM_DEBUG({
339*5652ecc3SRiver Riddle     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
340*5652ecc3SRiver Riddle                        << ")\n";
341*5652ecc3SRiver Riddle   });
342b7144ab7SRiver Riddle   for (auto result : op->getResults())
343b7144ab7SRiver Riddle     for (auto *user : result.getUsers())
344b7144ab7SRiver Riddle       addToWorklist(user);
345b7144ab7SRiver Riddle }
346b7144ab7SRiver Riddle 
347*5652ecc3SRiver Riddle void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
348*5652ecc3SRiver Riddle   LLVM_DEBUG({
349*5652ecc3SRiver Riddle     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
350*5652ecc3SRiver Riddle                        << ")\n";
351*5652ecc3SRiver Riddle   });
352*5652ecc3SRiver Riddle   PatternRewriter::eraseOp(op);
353*5652ecc3SRiver Riddle }
354*5652ecc3SRiver Riddle 
355*5652ecc3SRiver Riddle LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
356*5652ecc3SRiver Riddle     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
357*5652ecc3SRiver Riddle   LLVM_DEBUG({
358*5652ecc3SRiver Riddle     Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
359*5652ecc3SRiver Riddle     reasonCallback(diag);
360*5652ecc3SRiver Riddle     logger.startLine() << "** Failure : " << diag.str() << "\n";
361*5652ecc3SRiver Riddle   });
362*5652ecc3SRiver Riddle   return failure();
363*5652ecc3SRiver Riddle }
364*5652ecc3SRiver Riddle 
365e7a2ef21SRiver Riddle /// Rewrite the regions of the specified operation, which must be isolated from
366e7a2ef21SRiver Riddle /// above, by repeatedly applying the highest benefit patterns in a greedy
3673e98fbf4SRiver Riddle /// work-list driven manner. Return success if no more patterns can be matched
3683e98fbf4SRiver Riddle /// in the result operation regions. Note: This does not apply patterns to the
3693e98fbf4SRiver Riddle /// top-level operation itself.
37064d52014SChris Lattner ///
3713e98fbf4SRiver Riddle LogicalResult
3723e98fbf4SRiver Riddle mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
3730b20413eSUday Bondhugula                                    const FrozenRewritePatternSet &patterns,
37464716b2cSChris Lattner                                    GreedyRewriteConfig config) {
3756b1cc3c6SRiver Riddle   if (regions.empty())
3763e98fbf4SRiver Riddle     return success();
3776b1cc3c6SRiver Riddle 
378e7a2ef21SRiver Riddle   // The top-level operation must be known to be isolated from above to
379e7a2ef21SRiver Riddle   // prevent performing canonicalizations on operations defined at or above
380e7a2ef21SRiver Riddle   // the region containing 'op'.
3816b1cc3c6SRiver Riddle   auto regionIsIsolated = [](Region &region) {
382fe7c0d90SRiver Riddle     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
3836b1cc3c6SRiver Riddle   };
3846b1cc3c6SRiver Riddle   (void)regionIsIsolated;
3856b1cc3c6SRiver Riddle   assert(llvm::all_of(regions, regionIsIsolated) &&
3866b1cc3c6SRiver Riddle          "patterns can only be applied to operations IsolatedFromAbove");
387e7a2ef21SRiver Riddle 
3886b1cc3c6SRiver Riddle   // Start the pattern driver.
38964716b2cSChris Lattner   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
39064716b2cSChris Lattner   bool converged = driver.simplify(regions);
3915c757087SFeng Liu   LLVM_DEBUG(if (!converged) {
392e7a2ef21SRiver Riddle     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
39364716b2cSChris Lattner                  << config.maxIterations << " times\n";
3945c757087SFeng Liu   });
3953e98fbf4SRiver Riddle   return success(converged);
39664d52014SChris Lattner }
39704b5274eSUday Bondhugula 
39804b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
39904b5274eSUday Bondhugula // OpPatternRewriteDriver
40004b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
40104b5274eSUday Bondhugula 
40204b5274eSUday Bondhugula namespace {
40304b5274eSUday Bondhugula /// This is a simple driver for the PatternMatcher to apply patterns and perform
40404b5274eSUday Bondhugula /// folding on a single op. It repeatedly applies locally optimal patterns.
40504b5274eSUday Bondhugula class OpPatternRewriteDriver : public PatternRewriter {
40604b5274eSUday Bondhugula public:
40704b5274eSUday Bondhugula   explicit OpPatternRewriteDriver(MLIRContext *ctx,
40879d7f618SChris Lattner                                   const FrozenRewritePatternSet &patterns)
4093e98fbf4SRiver Riddle       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
4103e98fbf4SRiver Riddle     // Apply a simple cost model based solely on pattern benefit.
4113e98fbf4SRiver Riddle     matcher.applyDefaultCostModel();
4123e98fbf4SRiver Riddle   }
41304b5274eSUday Bondhugula 
4143e98fbf4SRiver Riddle   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
41504b5274eSUday Bondhugula 
41604b5274eSUday Bondhugula   // These are hooks implemented for PatternRewriter.
41704b5274eSUday Bondhugula protected:
41804b5274eSUday Bondhugula   /// If an operation is about to be removed, mark it so that we can let clients
41904b5274eSUday Bondhugula   /// know.
42004b5274eSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
42104b5274eSUday Bondhugula     opErasedViaPatternRewrites = true;
42204b5274eSUday Bondhugula   }
42304b5274eSUday Bondhugula 
42404b5274eSUday Bondhugula   // When a root is going to be replaced, its removal will be notified as well.
42504b5274eSUday Bondhugula   // So there is nothing to do here.
42604b5274eSUday Bondhugula   void notifyRootReplaced(Operation *op) override {}
42704b5274eSUday Bondhugula 
42804b5274eSUday Bondhugula private:
4293e98fbf4SRiver Riddle   /// The low-level pattern applicator.
4303e98fbf4SRiver Riddle   PatternApplicator matcher;
43104b5274eSUday Bondhugula 
43204b5274eSUday Bondhugula   /// Non-pattern based folder for operations.
43304b5274eSUday Bondhugula   OperationFolder folder;
43404b5274eSUday Bondhugula 
43504b5274eSUday Bondhugula   /// Set to true if the operation has been erased via pattern rewrites.
43604b5274eSUday Bondhugula   bool opErasedViaPatternRewrites = false;
43704b5274eSUday Bondhugula };
43804b5274eSUday Bondhugula 
43904b5274eSUday Bondhugula } // anonymous namespace
44004b5274eSUday Bondhugula 
4417932d21fSUday Bondhugula /// Performs the rewrites and folding only on `op`. The simplification
4427932d21fSUday Bondhugula /// converges if the op is erased as a result of being folded, replaced, or
4437932d21fSUday Bondhugula /// becoming dead, or no more changes happen in an iteration. Returns success if
4447932d21fSUday Bondhugula /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
4457932d21fSUday Bondhugula /// gets erased.
4463e98fbf4SRiver Riddle LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
4473e98fbf4SRiver Riddle                                                       int maxIterations,
44804b5274eSUday Bondhugula                                                       bool &erased) {
44904b5274eSUday Bondhugula   bool changed = false;
45004b5274eSUday Bondhugula   erased = false;
45104b5274eSUday Bondhugula   opErasedViaPatternRewrites = false;
4527932d21fSUday Bondhugula   int iterations = 0;
45304b5274eSUday Bondhugula   // Iterate until convergence or until maxIterations. Deletion of the op as
45404b5274eSUday Bondhugula   // a result of being dead or folded is convergence.
45504b5274eSUday Bondhugula   do {
456ff87c4d3SChristian Sigg     changed = false;
457ff87c4d3SChristian Sigg 
45804b5274eSUday Bondhugula     // If the operation is trivially dead - remove it.
45904b5274eSUday Bondhugula     if (isOpTriviallyDead(op)) {
46004b5274eSUday Bondhugula       op->erase();
46104b5274eSUday Bondhugula       erased = true;
4623e98fbf4SRiver Riddle       return success();
46304b5274eSUday Bondhugula     }
46404b5274eSUday Bondhugula 
46504b5274eSUday Bondhugula     // Try to fold this op.
46604b5274eSUday Bondhugula     bool inPlaceUpdate;
46704b5274eSUday Bondhugula     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
46804b5274eSUday Bondhugula                                    /*preReplaceAction=*/nullptr,
46904b5274eSUday Bondhugula                                    &inPlaceUpdate))) {
47004b5274eSUday Bondhugula       changed = true;
47104b5274eSUday Bondhugula       if (!inPlaceUpdate) {
47204b5274eSUday Bondhugula         erased = true;
4733e98fbf4SRiver Riddle         return success();
47404b5274eSUday Bondhugula       }
47504b5274eSUday Bondhugula     }
47604b5274eSUday Bondhugula 
47704b5274eSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
47804b5274eSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do here.
4793e98fbf4SRiver Riddle     changed |= succeeded(matcher.matchAndRewrite(op, *this));
48004b5274eSUday Bondhugula     if ((erased = opErasedViaPatternRewrites))
4813e98fbf4SRiver Riddle       return success();
482519663beSFrederik Gossen   } while (changed &&
483519663beSFrederik Gossen            (++iterations < maxIterations ||
484519663beSFrederik Gossen             maxIterations == GreedyRewriteConfig::kNoIterationLimit));
48504b5274eSUday Bondhugula 
48604b5274eSUday Bondhugula   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
4873e98fbf4SRiver Riddle   return failure(changed);
48804b5274eSUday Bondhugula }
48904b5274eSUday Bondhugula 
4907932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
4917932d21fSUday Bondhugula // MultiOpPatternRewriteDriver
4927932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
4937932d21fSUday Bondhugula 
4947932d21fSUday Bondhugula namespace {
4957932d21fSUday Bondhugula 
4967932d21fSUday Bondhugula /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
4977932d21fSUday Bondhugula /// perform folding for a supplied set of ops. It repeatedly simplifies while
4987932d21fSUday Bondhugula /// restricting the rewrites to only the provided set of ops or optionally
4997932d21fSUday Bondhugula /// to those directly affected by it (result users or operand providers).
5007932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
5017932d21fSUday Bondhugula public:
5027932d21fSUday Bondhugula   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
5037932d21fSUday Bondhugula                                        const FrozenRewritePatternSet &patterns,
5047932d21fSUday Bondhugula                                        bool strict)
5057932d21fSUday Bondhugula       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
5067932d21fSUday Bondhugula         strictMode(strict) {}
5077932d21fSUday Bondhugula 
5087932d21fSUday Bondhugula   bool simplifyLocally(ArrayRef<Operation *> op);
5097932d21fSUday Bondhugula 
5107932d21fSUday Bondhugula private:
5117932d21fSUday Bondhugula   // Look over the provided operands for any defining operations that should
5127932d21fSUday Bondhugula   // be re-added to the worklist. This function should be called when an
5137932d21fSUday Bondhugula   // operation is modified or removed, as it may trigger further
5147932d21fSUday Bondhugula   // simplifications. If `strict` is set to true, only ops in
5157932d21fSUday Bondhugula   // `strictModeFilteredOps` are considered.
5167932d21fSUday Bondhugula   template <typename Operands>
5177932d21fSUday Bondhugula   void addOperandsToWorklist(Operands &&operands) {
5187932d21fSUday Bondhugula     for (Value operand : operands) {
5197932d21fSUday Bondhugula       if (auto *defOp = operand.getDefiningOp()) {
5207932d21fSUday Bondhugula         if (!strictMode || strictModeFilteredOps.contains(defOp))
5217932d21fSUday Bondhugula           addToWorklist(defOp);
5227932d21fSUday Bondhugula       }
5237932d21fSUday Bondhugula     }
5247932d21fSUday Bondhugula   }
5257932d21fSUday Bondhugula 
5267932d21fSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
5277932d21fSUday Bondhugula     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
5287932d21fSUday Bondhugula     if (strictMode)
5297932d21fSUday Bondhugula       strictModeFilteredOps.erase(op);
5307932d21fSUday Bondhugula   }
5317932d21fSUday Bondhugula 
5327932d21fSUday Bondhugula   /// If `strictMode` is true, any pre-existing ops outside of
5337932d21fSUday Bondhugula   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
5347932d21fSUday Bondhugula   /// If `strictMode` is false, operations that use results of (or supply
5357932d21fSUday Bondhugula   /// operands to) any rewritten ops stemming from the simplification of the
5367932d21fSUday Bondhugula   /// provided ops are in turn simplified; any other ops still remain untouched
5377932d21fSUday Bondhugula   /// (i.e., regardless of `strictMode`).
5387932d21fSUday Bondhugula   bool strictMode = false;
5397932d21fSUday Bondhugula 
5407932d21fSUday Bondhugula   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
5417932d21fSUday Bondhugula   /// These include the supplied set of ops as well as new ops created while
5427932d21fSUday Bondhugula   /// rewriting those ops. This set is not maintained when strictMode is off.
5437932d21fSUday Bondhugula   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
5447932d21fSUday Bondhugula };
5457932d21fSUday Bondhugula 
5467932d21fSUday Bondhugula } // end anonymous namespace
5477932d21fSUday Bondhugula 
5487932d21fSUday Bondhugula /// Performs the specified rewrites on `ops` while also trying to fold these ops
5497932d21fSUday Bondhugula /// as well as any other ops that were in turn created due to these rewrite
5507932d21fSUday Bondhugula /// patterns. Any pre-existing ops outside of `ops` remain completely
5517932d21fSUday Bondhugula /// unmodified if `strictMode` is true. If `strictMode` is false, other
5527932d21fSUday Bondhugula /// operations that use results of rewritten ops or supply operands to such ops
5537932d21fSUday Bondhugula /// are in turn simplified; any other ops still remain unmodified (i.e.,
5547932d21fSUday Bondhugula /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
5557932d21fSUday Bondhugula /// result of folding, becoming dead, or via pattern rewrites. Returns true if
5567932d21fSUday Bondhugula /// at all any changes happened.
5577932d21fSUday Bondhugula // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
5587932d21fSUday Bondhugula // or GreedyPatternRewriteDriver::simplify, this method just iterates until
5597932d21fSUday Bondhugula // the worklist is empty. As our objective is to keep simplification "local",
5607932d21fSUday Bondhugula // there is no strong rationale to re-add all operations into the worklist and
5617932d21fSUday Bondhugula // rerun until an iteration changes nothing. If more widereaching simplification
5627932d21fSUday Bondhugula // is desired, GreedyPatternRewriteDriver should be used.
5637932d21fSUday Bondhugula bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
5647932d21fSUday Bondhugula   if (strictMode) {
5657932d21fSUday Bondhugula     strictModeFilteredOps.clear();
5667932d21fSUday Bondhugula     strictModeFilteredOps.insert(ops.begin(), ops.end());
5677932d21fSUday Bondhugula   }
5687932d21fSUday Bondhugula 
5697932d21fSUday Bondhugula   bool changed = false;
5707932d21fSUday Bondhugula   worklist.clear();
5717932d21fSUday Bondhugula   worklistMap.clear();
5727932d21fSUday Bondhugula   for (Operation *op : ops)
5737932d21fSUday Bondhugula     addToWorklist(op);
5747932d21fSUday Bondhugula 
5757932d21fSUday Bondhugula   // These are scratch vectors used in the folding loop below.
5767932d21fSUday Bondhugula   SmallVector<Value, 8> originalOperands, resultValues;
5777932d21fSUday Bondhugula   while (!worklist.empty()) {
5787932d21fSUday Bondhugula     Operation *op = popFromWorklist();
5797932d21fSUday Bondhugula 
5807932d21fSUday Bondhugula     // Nulls get added to the worklist when operations are removed, ignore
5817932d21fSUday Bondhugula     // them.
5827932d21fSUday Bondhugula     if (op == nullptr)
5837932d21fSUday Bondhugula       continue;
5847932d21fSUday Bondhugula 
5857932d21fSUday Bondhugula     // If the operation is trivially dead - remove it.
5867932d21fSUday Bondhugula     if (isOpTriviallyDead(op)) {
5877932d21fSUday Bondhugula       notifyOperationRemoved(op);
5887932d21fSUday Bondhugula       op->erase();
5897932d21fSUday Bondhugula       changed = true;
5907932d21fSUday Bondhugula       continue;
5917932d21fSUday Bondhugula     }
5927932d21fSUday Bondhugula 
5937932d21fSUday Bondhugula     // Collects all the operands and result uses of the given `op` into work
5947932d21fSUday Bondhugula     // list. Also remove `op` and nested ops from worklist.
5957932d21fSUday Bondhugula     originalOperands.assign(op->operand_begin(), op->operand_end());
5967932d21fSUday Bondhugula     auto preReplaceAction = [&](Operation *op) {
5977932d21fSUday Bondhugula       // Add the operands to the worklist for visitation.
5987932d21fSUday Bondhugula       addOperandsToWorklist(originalOperands);
5997932d21fSUday Bondhugula 
6007932d21fSUday Bondhugula       // Add all the users of the result to the worklist so we make sure
6017932d21fSUday Bondhugula       // to revisit them.
6027932d21fSUday Bondhugula       for (Value result : op->getResults())
6037932d21fSUday Bondhugula         for (Operation *userOp : result.getUsers()) {
6047932d21fSUday Bondhugula           if (!strictMode || strictModeFilteredOps.contains(userOp))
6057932d21fSUday Bondhugula             addToWorklist(userOp);
6067932d21fSUday Bondhugula         }
6077932d21fSUday Bondhugula       notifyOperationRemoved(op);
6087932d21fSUday Bondhugula     };
6097932d21fSUday Bondhugula 
6107932d21fSUday Bondhugula     // Add the given operation generated by the folder to the worklist.
6117932d21fSUday Bondhugula     auto processGeneratedConstants = [this](Operation *op) {
6127932d21fSUday Bondhugula       // Newly created ops are also simplified -- these are also "local".
6137932d21fSUday Bondhugula       addToWorklist(op);
6147932d21fSUday Bondhugula       // When strict mode is off, we don't need to maintain
6157932d21fSUday Bondhugula       // strictModeFilteredOps.
6167932d21fSUday Bondhugula       if (strictMode)
6177932d21fSUday Bondhugula         strictModeFilteredOps.insert(op);
6187932d21fSUday Bondhugula     };
6197932d21fSUday Bondhugula 
6207932d21fSUday Bondhugula     // Try to fold this op.
6217932d21fSUday Bondhugula     bool inPlaceUpdate;
6227932d21fSUday Bondhugula     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
6237932d21fSUday Bondhugula                                    preReplaceAction, &inPlaceUpdate))) {
6247932d21fSUday Bondhugula       changed = true;
6257932d21fSUday Bondhugula       if (!inPlaceUpdate) {
6267932d21fSUday Bondhugula         // Op has been erased.
6277932d21fSUday Bondhugula         continue;
6287932d21fSUday Bondhugula       }
6297932d21fSUday Bondhugula     }
6307932d21fSUday Bondhugula 
6317932d21fSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
6327932d21fSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do
6337932d21fSUday Bondhugula     // here.
6347932d21fSUday Bondhugula     changed |= succeeded(matcher.matchAndRewrite(op, *this));
6357932d21fSUday Bondhugula   }
6367932d21fSUday Bondhugula 
6377932d21fSUday Bondhugula   return changed;
6387932d21fSUday Bondhugula }
6397932d21fSUday Bondhugula 
64004b5274eSUday Bondhugula /// Rewrites only `op` using the supplied canonicalization patterns and
64104b5274eSUday Bondhugula /// folding. `erased` is set to true if the op is erased as a result of being
64204b5274eSUday Bondhugula /// folded, replaced, or dead.
6433e98fbf4SRiver Riddle LogicalResult mlir::applyOpPatternsAndFold(
64479d7f618SChris Lattner     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
64504b5274eSUday Bondhugula   // Start the pattern driver.
64664716b2cSChris Lattner   GreedyRewriteConfig config;
64704b5274eSUday Bondhugula   OpPatternRewriteDriver driver(op->getContext(), patterns);
64804b5274eSUday Bondhugula   bool opErased;
6493e98fbf4SRiver Riddle   LogicalResult converged =
65064716b2cSChris Lattner       driver.simplifyLocally(op, config.maxIterations, opErased);
65104b5274eSUday Bondhugula   if (erased)
65204b5274eSUday Bondhugula     *erased = opErased;
6533e98fbf4SRiver Riddle   LLVM_DEBUG(if (failed(converged)) {
65404b5274eSUday Bondhugula     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
65564716b2cSChris Lattner                  << config.maxIterations << " times";
65604b5274eSUday Bondhugula   });
65704b5274eSUday Bondhugula   return converged;
65804b5274eSUday Bondhugula }
6597932d21fSUday Bondhugula 
6607932d21fSUday Bondhugula bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
6617932d21fSUday Bondhugula                                   const FrozenRewritePatternSet &patterns,
6627932d21fSUday Bondhugula                                   bool strict) {
6637932d21fSUday Bondhugula   if (ops.empty())
6647932d21fSUday Bondhugula     return false;
6657932d21fSUday Bondhugula 
6667932d21fSUday Bondhugula   // Start the pattern driver.
6677932d21fSUday Bondhugula   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
6687932d21fSUday Bondhugula                                      strict);
6697932d21fSUday Bondhugula   return driver.simplifyLocally(ops);
6707932d21fSUday Bondhugula }
671