164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
664d52014SChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
864d52014SChris Lattner //
9a60fdd2bSLorenzo Chelini // This file implements mlir::applyPatternsAndFoldGreedily.
1064d52014SChris Lattner //
1164d52014SChris Lattner //===----------------------------------------------------------------------===//
1264d52014SChris Lattner 
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
15b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
161982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
17fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
1864d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
195c757087SFeng Liu #include "llvm/Support/CommandLine.h"
205c757087SFeng Liu #include "llvm/Support/Debug.h"
215652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h"
225c757087SFeng Liu #include "llvm/Support/raw_ostream.h"
234e40c832SLei Zhang 
2464d52014SChris Lattner using namespace mlir;
2564d52014SChris Lattner 
265652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter"
275c757087SFeng Liu 
2804b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
2904b5274eSUday Bondhugula // GreedyPatternRewriteDriver
3004b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
3104b5274eSUday Bondhugula 
3264d52014SChris Lattner namespace {
3364d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3464d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
354bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3664d52014SChris Lattner public:
372566a72aSRiver Riddle   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
38648f34a2SChris Lattner                                       const FrozenRewritePatternSet &patterns,
39b7144ab7SRiver Riddle                                       const GreedyRewriteConfig &config);
403e98fbf4SRiver Riddle 
41b7144ab7SRiver Riddle   /// Simplify the operations within the given regions.
4264716b2cSChris Lattner   bool simplify(MutableArrayRef<Region> regions);
4364d52014SChris Lattner 
44b7144ab7SRiver Riddle   /// Add the given operation to the worklist.
45b7144ab7SRiver Riddle   void addToWorklist(Operation *op);
465c4f1fddSRiver Riddle 
47b7144ab7SRiver Riddle   /// Pop the next operation from the worklist.
48b7144ab7SRiver Riddle   Operation *popFromWorklist();
4964d52014SChris Lattner 
50b7144ab7SRiver Riddle   /// If the specified operation is in the worklist, remove it.
51b7144ab7SRiver Riddle   void removeFromWorklist(Operation *op);
5264d52014SChris Lattner 
534bd9f936SChris Lattner protected:
54851a8516SRiver Riddle   // Implement the hook for inserting operations, and make sure that newly
55851a8516SRiver Riddle   // inserted ops are added to the worklist for processing.
56b7144ab7SRiver Riddle   void notifyOperationInserted(Operation *op) override;
5764d52014SChris Lattner 
587932d21fSUday Bondhugula   // Look over the provided operands for any defining operations that should
597932d21fSUday Bondhugula   // be re-added to the worklist. This function should be called when an
607932d21fSUday Bondhugula   // operation is modified or removed, as it may trigger further
617932d21fSUday Bondhugula   // simplifications.
627932d21fSUday Bondhugula   template <typename Operands>
63b7144ab7SRiver Riddle   void addToWorklist(Operands &&operands);
647932d21fSUday Bondhugula 
6564d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
6664d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
67b7144ab7SRiver Riddle   void notifyOperationRemoved(Operation *op) override;
6864d52014SChris Lattner 
69085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
70085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
71085b687fSChris Lattner   // before the root is changed.
72b7144ab7SRiver Riddle   void notifyRootReplaced(Operation *op) override;
73085b687fSChris Lattner 
745652ecc3SRiver Riddle   /// PatternRewriter hook for erasing a dead operation.
755652ecc3SRiver Riddle   void eraseOp(Operation *op) override;
765652ecc3SRiver Riddle 
775652ecc3SRiver Riddle   /// PatternRewriter hook for notifying match failure reasons.
785652ecc3SRiver Riddle   LogicalResult
79*ea64828aSRiver Riddle   notifyMatchFailure(Location loc,
805652ecc3SRiver Riddle                      function_ref<void(Diagnostic &)> reasonCallback) override;
815652ecc3SRiver Riddle 
823e98fbf4SRiver Riddle   /// The low-level pattern applicator.
833e98fbf4SRiver Riddle   PatternApplicator matcher;
844bd9f936SChris Lattner 
854bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
864bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
87e7a2ef21SRiver Riddle   /// efficiently remove operations from the worklist when they are erased, even
88e7a2ef21SRiver Riddle   /// if they aren't the root of a pattern.
8999b87c97SRiver Riddle   std::vector<Operation *> worklist;
9099b87c97SRiver Riddle   DenseMap<Operation *, unsigned> worklistMap;
9160a29837SRiver Riddle 
9260a29837SRiver Riddle   /// Non-pattern based folder for operations.
9360a29837SRiver Riddle   OperationFolder folder;
94648f34a2SChris Lattner 
957932d21fSUday Bondhugula private:
9664716b2cSChris Lattner   /// Configuration information for how to simplify.
9764716b2cSChris Lattner   GreedyRewriteConfig config;
985652ecc3SRiver Riddle 
995652ecc3SRiver Riddle #ifndef NDEBUG
1005652ecc3SRiver Riddle   /// A logger used to emit information during the application process.
1015652ecc3SRiver Riddle   llvm::ScopedPrinter logger{llvm::dbgs()};
1025652ecc3SRiver Riddle #endif
10364d52014SChris Lattner };
104be0a7e9fSMehdi Amini } // namespace
10564d52014SChris Lattner 
106b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
107b7144ab7SRiver Riddle     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
108b7144ab7SRiver Riddle     const GreedyRewriteConfig &config)
109b7144ab7SRiver Riddle     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
110b7144ab7SRiver Riddle   worklist.reserve(64);
111b7144ab7SRiver Riddle 
112b7144ab7SRiver Riddle   // Apply a simple cost model based solely on pattern benefit.
113b7144ab7SRiver Riddle   matcher.applyDefaultCostModel();
114b7144ab7SRiver Riddle }
115b7144ab7SRiver Riddle 
11664716b2cSChris Lattner bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
1175652ecc3SRiver Riddle #ifndef NDEBUG
1185652ecc3SRiver Riddle   const char *logLineComment =
1195652ecc3SRiver Riddle       "//===-------------------------------------------===//\n";
1205652ecc3SRiver Riddle 
1215652ecc3SRiver Riddle   /// A utility function to log a process result for the given reason.
1225652ecc3SRiver Riddle   auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
1235652ecc3SRiver Riddle     logger.unindent();
1245652ecc3SRiver Riddle     logger.startLine() << "} -> " << result;
1255652ecc3SRiver Riddle     if (!msg.isTriviallyEmpty())
1265652ecc3SRiver Riddle       logger.getOStream() << " : " << msg;
1275652ecc3SRiver Riddle     logger.getOStream() << "\n";
1285652ecc3SRiver Riddle   };
1295652ecc3SRiver Riddle   auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
1305652ecc3SRiver Riddle     logResult(result, msg);
1315652ecc3SRiver Riddle     logger.startLine() << logLineComment;
1325652ecc3SRiver Riddle   };
1335652ecc3SRiver Riddle #endif
1345652ecc3SRiver Riddle 
1355c757087SFeng Liu   bool changed = false;
13664716b2cSChris Lattner   unsigned iteration = 0;
1375c757087SFeng Liu   do {
138648f34a2SChris Lattner     worklist.clear();
139648f34a2SChris Lattner     worklistMap.clear();
140648f34a2SChris Lattner 
14164716b2cSChris Lattner     if (!config.useTopDownTraversal) {
14264716b2cSChris Lattner       // Add operations to the worklist in postorder.
143ba43d6f8SMehdi Amini       for (auto &region : regions)
144ba43d6f8SMehdi Amini         region.walk([this](Operation *op) { addToWorklist(op); });
14564716b2cSChris Lattner     } else {
146648f34a2SChris Lattner       // Add all nested operations to the worklist in preorder.
1476b1cc3c6SRiver Riddle       for (auto &region : regions)
148648f34a2SChris Lattner         region.walk<WalkOrder::PreOrder>(
149648f34a2SChris Lattner             [this](Operation *op) { worklist.push_back(op); });
150648f34a2SChris Lattner 
151648f34a2SChris Lattner       // Reverse the list so our pop-back loop processes them in-order.
152648f34a2SChris Lattner       std::reverse(worklist.begin(), worklist.end());
153648f34a2SChris Lattner       // Remember the reverse index.
154648f34a2SChris Lattner       for (size_t i = 0, e = worklist.size(); i != e; ++i)
155648f34a2SChris Lattner         worklistMap[worklist[i]] = i;
156648f34a2SChris Lattner     }
1574e40c832SLei Zhang 
1584e40c832SLei Zhang     // These are scratch vectors used in the folding loop below.
159e62a6956SRiver Riddle     SmallVector<Value, 8> originalOperands, resultValues;
16064d52014SChris Lattner 
1615c757087SFeng Liu     changed = false;
16264d52014SChris Lattner     while (!worklist.empty()) {
16364d52014SChris Lattner       auto *op = popFromWorklist();
16464d52014SChris Lattner 
1655c757087SFeng Liu       // Nulls get added to the worklist when operations are removed, ignore
1665c757087SFeng Liu       // them.
16764d52014SChris Lattner       if (op == nullptr)
16864d52014SChris Lattner         continue;
16964d52014SChris Lattner 
1705652ecc3SRiver Riddle       LLVM_DEBUG({
1715652ecc3SRiver Riddle         logger.getOStream() << "\n";
1725652ecc3SRiver Riddle         logger.startLine() << logLineComment;
1735652ecc3SRiver Riddle         logger.startLine() << "Processing operation : '" << op->getName()
1745652ecc3SRiver Riddle                            << "'(" << op << ") {\n";
1755652ecc3SRiver Riddle         logger.indent();
1765652ecc3SRiver Riddle 
1775652ecc3SRiver Riddle         // If the operation has no regions, just print it here.
1785652ecc3SRiver Riddle         if (op->getNumRegions() == 0) {
1795652ecc3SRiver Riddle           op->print(
1805652ecc3SRiver Riddle               logger.startLine(),
1815652ecc3SRiver Riddle               OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
1825652ecc3SRiver Riddle           logger.getOStream() << "\n\n";
1835652ecc3SRiver Riddle         }
1845652ecc3SRiver Riddle       });
1855652ecc3SRiver Riddle 
1860ddba0bdSRiver Riddle       // If the operation is trivially dead - remove it.
1870ddba0bdSRiver Riddle       if (isOpTriviallyDead(op)) {
1886a501e3dSAndy Ly         notifyOperationRemoved(op);
18964d52014SChris Lattner         op->erase();
190f875e55bSUday Bondhugula         changed = true;
1915652ecc3SRiver Riddle 
1925652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
19364d52014SChris Lattner         continue;
19464d52014SChris Lattner       }
19564d52014SChris Lattner 
1964e40c832SLei Zhang       // Collects all the operands and result uses of the given `op` into work
1976a501e3dSAndy Ly       // list. Also remove `op` and nested ops from worklist.
1981982afb1SRiver Riddle       originalOperands.assign(op->operand_begin(), op->operand_end());
1996a501e3dSAndy Ly       auto preReplaceAction = [&](Operation *op) {
200a8866258SRiver Riddle         // Add the operands to the worklist for visitation.
2011982afb1SRiver Riddle         addToWorklist(originalOperands);
2021982afb1SRiver Riddle 
2034e40c832SLei Zhang         // Add all the users of the result to the worklist so we make sure
2044e40c832SLei Zhang         // to revisit them.
20535807bc4SRiver Riddle         for (auto result : op->getResults())
206cc673894SUday Bondhugula           for (auto *userOp : result.getUsers())
207cc673894SUday Bondhugula             addToWorklist(userOp);
2086a501e3dSAndy Ly 
2096a501e3dSAndy Ly         notifyOperationRemoved(op);
2104e40c832SLei Zhang       };
21164d52014SChris Lattner 
212648f34a2SChris Lattner       // Add the given operation to the worklist.
213648f34a2SChris Lattner       auto collectOps = [this](Operation *op) { addToWorklist(op); };
214648f34a2SChris Lattner 
2151982afb1SRiver Riddle       // Try to fold this op.
216cbcb12fdSUday Bondhugula       bool inPlaceUpdate;
217cbcb12fdSUday Bondhugula       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
218cbcb12fdSUday Bondhugula                                       &inPlaceUpdate)))) {
2195652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
2205652ecc3SRiver Riddle 
221f875e55bSUday Bondhugula         changed = true;
222cbcb12fdSUday Bondhugula         if (!inPlaceUpdate)
223934b6d12SChris Lattner           continue;
22464d52014SChris Lattner       }
22564d52014SChris Lattner 
22632052c84SRiver Riddle       // Try to match one of the patterns. The rewriter is automatically
227648f34a2SChris Lattner       // notified of any necessary changes, so there is nothing else to do
228648f34a2SChris Lattner       // here.
2295652ecc3SRiver Riddle #ifndef NDEBUG
2305652ecc3SRiver Riddle       auto canApply = [&](const Pattern &pattern) {
2315652ecc3SRiver Riddle         LLVM_DEBUG({
2325652ecc3SRiver Riddle           logger.getOStream() << "\n";
2335652ecc3SRiver Riddle           logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
2345652ecc3SRiver Riddle                              << op->getName() << " -> (";
2355652ecc3SRiver Riddle           llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
2365652ecc3SRiver Riddle           logger.getOStream() << ")' {\n";
2375652ecc3SRiver Riddle           logger.indent();
2385652ecc3SRiver Riddle         });
2395652ecc3SRiver Riddle         return true;
2405652ecc3SRiver Riddle       };
2415652ecc3SRiver Riddle       auto onFailure = [&](const Pattern &pattern) {
2425652ecc3SRiver Riddle         LLVM_DEBUG(logResult("failure", "pattern failed to match"));
2435652ecc3SRiver Riddle       };
2445652ecc3SRiver Riddle       auto onSuccess = [&](const Pattern &pattern) {
2455652ecc3SRiver Riddle         LLVM_DEBUG(logResult("success", "pattern applied successfully"));
2465652ecc3SRiver Riddle         return success();
2475652ecc3SRiver Riddle       };
2485652ecc3SRiver Riddle 
2495652ecc3SRiver Riddle       LogicalResult matchResult =
2505652ecc3SRiver Riddle           matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
2515652ecc3SRiver Riddle       if (succeeded(matchResult))
2525652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
2535652ecc3SRiver Riddle       else
2545652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
2555652ecc3SRiver Riddle #else
2565652ecc3SRiver Riddle       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
2575652ecc3SRiver Riddle #endif
2585652ecc3SRiver Riddle       changed |= succeeded(matchResult);
25964d52014SChris Lattner     }
260a32f0dcbSRiver Riddle 
261648f34a2SChris Lattner     // After applying patterns, make sure that the CFG of each of the regions
262648f34a2SChris Lattner     // is kept up to date.
26364716b2cSChris Lattner     if (config.enableRegionSimplification)
264d75a611aSRiver Riddle       changed |= succeeded(simplifyRegions(*this, regions));
265519663beSFrederik Gossen   } while (changed &&
266519663beSFrederik Gossen            (++iteration < config.maxIterations ||
267519663beSFrederik Gossen             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
268648f34a2SChris Lattner 
2695c757087SFeng Liu   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
2705c757087SFeng Liu   return !changed;
27164d52014SChris Lattner }
27264d52014SChris Lattner 
273b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
274b7144ab7SRiver Riddle   // Check to see if the worklist already contains this op.
275b7144ab7SRiver Riddle   if (worklistMap.count(op))
276b7144ab7SRiver Riddle     return;
277b7144ab7SRiver Riddle 
278b7144ab7SRiver Riddle   worklistMap[op] = worklist.size();
279b7144ab7SRiver Riddle   worklist.push_back(op);
280b7144ab7SRiver Riddle }
281b7144ab7SRiver Riddle 
282b7144ab7SRiver Riddle Operation *GreedyPatternRewriteDriver::popFromWorklist() {
283b7144ab7SRiver Riddle   auto *op = worklist.back();
284b7144ab7SRiver Riddle   worklist.pop_back();
285b7144ab7SRiver Riddle 
286b7144ab7SRiver Riddle   // This operation is no longer in the worklist, keep worklistMap up to date.
287b7144ab7SRiver Riddle   if (op)
288b7144ab7SRiver Riddle     worklistMap.erase(op);
289b7144ab7SRiver Riddle   return op;
290b7144ab7SRiver Riddle }
291b7144ab7SRiver Riddle 
292b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
293b7144ab7SRiver Riddle   auto it = worklistMap.find(op);
294b7144ab7SRiver Riddle   if (it != worklistMap.end()) {
295b7144ab7SRiver Riddle     assert(worklist[it->second] == op && "malformed worklist data structure");
296b7144ab7SRiver Riddle     worklist[it->second] = nullptr;
297b7144ab7SRiver Riddle     worklistMap.erase(it);
298b7144ab7SRiver Riddle   }
299b7144ab7SRiver Riddle }
300b7144ab7SRiver Riddle 
301b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
3025652ecc3SRiver Riddle   LLVM_DEBUG({
3035652ecc3SRiver Riddle     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
3045652ecc3SRiver Riddle                        << ")\n";
3055652ecc3SRiver Riddle   });
306b7144ab7SRiver Riddle   addToWorklist(op);
307b7144ab7SRiver Riddle }
308b7144ab7SRiver Riddle 
309b7144ab7SRiver Riddle template <typename Operands>
310b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
311b7144ab7SRiver Riddle   for (Value operand : operands) {
312b7144ab7SRiver Riddle     // If the use count of this operand is now < 2, we re-add the defining
313b7144ab7SRiver Riddle     // operation to the worklist.
314b7144ab7SRiver Riddle     // TODO: This is based on the fact that zero use operations
315b7144ab7SRiver Riddle     // may be deleted, and that single use values often have more
316b7144ab7SRiver Riddle     // canonicalization opportunities.
317b7144ab7SRiver Riddle     if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
318b7144ab7SRiver Riddle       continue;
319b7144ab7SRiver Riddle     if (auto *defOp = operand.getDefiningOp())
320b7144ab7SRiver Riddle       addToWorklist(defOp);
321b7144ab7SRiver Riddle   }
322b7144ab7SRiver Riddle }
323b7144ab7SRiver Riddle 
324b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
325b7144ab7SRiver Riddle   addToWorklist(op->getOperands());
326b7144ab7SRiver Riddle   op->walk([this](Operation *operation) {
327b7144ab7SRiver Riddle     removeFromWorklist(operation);
328b7144ab7SRiver Riddle     folder.notifyRemoval(operation);
329b7144ab7SRiver Riddle   });
330b7144ab7SRiver Riddle }
331b7144ab7SRiver Riddle 
332b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
3335652ecc3SRiver Riddle   LLVM_DEBUG({
3345652ecc3SRiver Riddle     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
3355652ecc3SRiver Riddle                        << ")\n";
3365652ecc3SRiver Riddle   });
337b7144ab7SRiver Riddle   for (auto result : op->getResults())
338b7144ab7SRiver Riddle     for (auto *user : result.getUsers())
339b7144ab7SRiver Riddle       addToWorklist(user);
340b7144ab7SRiver Riddle }
341b7144ab7SRiver Riddle 
3425652ecc3SRiver Riddle void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
3435652ecc3SRiver Riddle   LLVM_DEBUG({
3445652ecc3SRiver Riddle     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
3455652ecc3SRiver Riddle                        << ")\n";
3465652ecc3SRiver Riddle   });
3475652ecc3SRiver Riddle   PatternRewriter::eraseOp(op);
3485652ecc3SRiver Riddle }
3495652ecc3SRiver Riddle 
3505652ecc3SRiver Riddle LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
351*ea64828aSRiver Riddle     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
3525652ecc3SRiver Riddle   LLVM_DEBUG({
353*ea64828aSRiver Riddle     Diagnostic diag(loc, DiagnosticSeverity::Remark);
3545652ecc3SRiver Riddle     reasonCallback(diag);
3555652ecc3SRiver Riddle     logger.startLine() << "** Failure : " << diag.str() << "\n";
3565652ecc3SRiver Riddle   });
3575652ecc3SRiver Riddle   return failure();
3585652ecc3SRiver Riddle }
3595652ecc3SRiver Riddle 
360e7a2ef21SRiver Riddle /// Rewrite the regions of the specified operation, which must be isolated from
361e7a2ef21SRiver Riddle /// above, by repeatedly applying the highest benefit patterns in a greedy
3623e98fbf4SRiver Riddle /// work-list driven manner. Return success if no more patterns can be matched
3633e98fbf4SRiver Riddle /// in the result operation regions. Note: This does not apply patterns to the
3643e98fbf4SRiver Riddle /// top-level operation itself.
36564d52014SChris Lattner ///
3663e98fbf4SRiver Riddle LogicalResult
3673e98fbf4SRiver Riddle mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
3680b20413eSUday Bondhugula                                    const FrozenRewritePatternSet &patterns,
36964716b2cSChris Lattner                                    GreedyRewriteConfig config) {
3706b1cc3c6SRiver Riddle   if (regions.empty())
3713e98fbf4SRiver Riddle     return success();
3726b1cc3c6SRiver Riddle 
373e7a2ef21SRiver Riddle   // The top-level operation must be known to be isolated from above to
374e7a2ef21SRiver Riddle   // prevent performing canonicalizations on operations defined at or above
375e7a2ef21SRiver Riddle   // the region containing 'op'.
3766b1cc3c6SRiver Riddle   auto regionIsIsolated = [](Region &region) {
377fe7c0d90SRiver Riddle     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
3786b1cc3c6SRiver Riddle   };
3796b1cc3c6SRiver Riddle   (void)regionIsIsolated;
3806b1cc3c6SRiver Riddle   assert(llvm::all_of(regions, regionIsIsolated) &&
3816b1cc3c6SRiver Riddle          "patterns can only be applied to operations IsolatedFromAbove");
382e7a2ef21SRiver Riddle 
3836b1cc3c6SRiver Riddle   // Start the pattern driver.
38464716b2cSChris Lattner   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
38564716b2cSChris Lattner   bool converged = driver.simplify(regions);
3865c757087SFeng Liu   LLVM_DEBUG(if (!converged) {
387e7a2ef21SRiver Riddle     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
38864716b2cSChris Lattner                  << config.maxIterations << " times\n";
3895c757087SFeng Liu   });
3903e98fbf4SRiver Riddle   return success(converged);
39164d52014SChris Lattner }
39204b5274eSUday Bondhugula 
39304b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
39404b5274eSUday Bondhugula // OpPatternRewriteDriver
39504b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
39604b5274eSUday Bondhugula 
39704b5274eSUday Bondhugula namespace {
39804b5274eSUday Bondhugula /// This is a simple driver for the PatternMatcher to apply patterns and perform
39904b5274eSUday Bondhugula /// folding on a single op. It repeatedly applies locally optimal patterns.
40004b5274eSUday Bondhugula class OpPatternRewriteDriver : public PatternRewriter {
40104b5274eSUday Bondhugula public:
40204b5274eSUday Bondhugula   explicit OpPatternRewriteDriver(MLIRContext *ctx,
40379d7f618SChris Lattner                                   const FrozenRewritePatternSet &patterns)
4043e98fbf4SRiver Riddle       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
4053e98fbf4SRiver Riddle     // Apply a simple cost model based solely on pattern benefit.
4063e98fbf4SRiver Riddle     matcher.applyDefaultCostModel();
4073e98fbf4SRiver Riddle   }
40804b5274eSUday Bondhugula 
4093e98fbf4SRiver Riddle   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
41004b5274eSUday Bondhugula 
41104b5274eSUday Bondhugula   // These are hooks implemented for PatternRewriter.
41204b5274eSUday Bondhugula protected:
41304b5274eSUday Bondhugula   /// If an operation is about to be removed, mark it so that we can let clients
41404b5274eSUday Bondhugula   /// know.
41504b5274eSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
41604b5274eSUday Bondhugula     opErasedViaPatternRewrites = true;
41704b5274eSUday Bondhugula   }
41804b5274eSUday Bondhugula 
41904b5274eSUday Bondhugula   // When a root is going to be replaced, its removal will be notified as well.
42004b5274eSUday Bondhugula   // So there is nothing to do here.
42104b5274eSUday Bondhugula   void notifyRootReplaced(Operation *op) override {}
42204b5274eSUday Bondhugula 
42304b5274eSUday Bondhugula private:
4243e98fbf4SRiver Riddle   /// The low-level pattern applicator.
4253e98fbf4SRiver Riddle   PatternApplicator matcher;
42604b5274eSUday Bondhugula 
42704b5274eSUday Bondhugula   /// Non-pattern based folder for operations.
42804b5274eSUday Bondhugula   OperationFolder folder;
42904b5274eSUday Bondhugula 
43004b5274eSUday Bondhugula   /// Set to true if the operation has been erased via pattern rewrites.
43104b5274eSUday Bondhugula   bool opErasedViaPatternRewrites = false;
43204b5274eSUday Bondhugula };
43304b5274eSUday Bondhugula 
434be0a7e9fSMehdi Amini } // namespace
43504b5274eSUday Bondhugula 
4367932d21fSUday Bondhugula /// Performs the rewrites and folding only on `op`. The simplification
4377932d21fSUday Bondhugula /// converges if the op is erased as a result of being folded, replaced, or
4387932d21fSUday Bondhugula /// becoming dead, or no more changes happen in an iteration. Returns success if
4397932d21fSUday Bondhugula /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
4407932d21fSUday Bondhugula /// gets erased.
4413e98fbf4SRiver Riddle LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
4423e98fbf4SRiver Riddle                                                       int maxIterations,
44304b5274eSUday Bondhugula                                                       bool &erased) {
44404b5274eSUday Bondhugula   bool changed = false;
44504b5274eSUday Bondhugula   erased = false;
44604b5274eSUday Bondhugula   opErasedViaPatternRewrites = false;
4477932d21fSUday Bondhugula   int iterations = 0;
44804b5274eSUday Bondhugula   // Iterate until convergence or until maxIterations. Deletion of the op as
44904b5274eSUday Bondhugula   // a result of being dead or folded is convergence.
45004b5274eSUday Bondhugula   do {
451ff87c4d3SChristian Sigg     changed = false;
452ff87c4d3SChristian Sigg 
45304b5274eSUday Bondhugula     // If the operation is trivially dead - remove it.
45404b5274eSUday Bondhugula     if (isOpTriviallyDead(op)) {
45504b5274eSUday Bondhugula       op->erase();
45604b5274eSUday Bondhugula       erased = true;
4573e98fbf4SRiver Riddle       return success();
45804b5274eSUday Bondhugula     }
45904b5274eSUday Bondhugula 
46004b5274eSUday Bondhugula     // Try to fold this op.
46104b5274eSUday Bondhugula     bool inPlaceUpdate;
46204b5274eSUday Bondhugula     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
46304b5274eSUday Bondhugula                                    /*preReplaceAction=*/nullptr,
46404b5274eSUday Bondhugula                                    &inPlaceUpdate))) {
46504b5274eSUday Bondhugula       changed = true;
46604b5274eSUday Bondhugula       if (!inPlaceUpdate) {
46704b5274eSUday Bondhugula         erased = true;
4683e98fbf4SRiver Riddle         return success();
46904b5274eSUday Bondhugula       }
47004b5274eSUday Bondhugula     }
47104b5274eSUday Bondhugula 
47204b5274eSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
47304b5274eSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do here.
4743e98fbf4SRiver Riddle     changed |= succeeded(matcher.matchAndRewrite(op, *this));
47504b5274eSUday Bondhugula     if ((erased = opErasedViaPatternRewrites))
4763e98fbf4SRiver Riddle       return success();
477519663beSFrederik Gossen   } while (changed &&
478519663beSFrederik Gossen            (++iterations < maxIterations ||
479519663beSFrederik Gossen             maxIterations == GreedyRewriteConfig::kNoIterationLimit));
48004b5274eSUday Bondhugula 
48104b5274eSUday Bondhugula   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
4823e98fbf4SRiver Riddle   return failure(changed);
48304b5274eSUday Bondhugula }
48404b5274eSUday Bondhugula 
4857932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
4867932d21fSUday Bondhugula // MultiOpPatternRewriteDriver
4877932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
4887932d21fSUday Bondhugula 
4897932d21fSUday Bondhugula namespace {
4907932d21fSUday Bondhugula 
4917932d21fSUday Bondhugula /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
4927932d21fSUday Bondhugula /// perform folding for a supplied set of ops. It repeatedly simplifies while
4937932d21fSUday Bondhugula /// restricting the rewrites to only the provided set of ops or optionally
4947932d21fSUday Bondhugula /// to those directly affected by it (result users or operand providers).
4957932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
4967932d21fSUday Bondhugula public:
4977932d21fSUday Bondhugula   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
4987932d21fSUday Bondhugula                                        const FrozenRewritePatternSet &patterns,
4997932d21fSUday Bondhugula                                        bool strict)
5007932d21fSUday Bondhugula       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
5017932d21fSUday Bondhugula         strictMode(strict) {}
5027932d21fSUday Bondhugula 
5037932d21fSUday Bondhugula   bool simplifyLocally(ArrayRef<Operation *> op);
5047932d21fSUday Bondhugula 
5057932d21fSUday Bondhugula private:
5067932d21fSUday Bondhugula   // Look over the provided operands for any defining operations that should
5077932d21fSUday Bondhugula   // be re-added to the worklist. This function should be called when an
5087932d21fSUday Bondhugula   // operation is modified or removed, as it may trigger further
5097932d21fSUday Bondhugula   // simplifications. If `strict` is set to true, only ops in
5107932d21fSUday Bondhugula   // `strictModeFilteredOps` are considered.
5117932d21fSUday Bondhugula   template <typename Operands>
5127932d21fSUday Bondhugula   void addOperandsToWorklist(Operands &&operands) {
5137932d21fSUday Bondhugula     for (Value operand : operands) {
5147932d21fSUday Bondhugula       if (auto *defOp = operand.getDefiningOp()) {
5157932d21fSUday Bondhugula         if (!strictMode || strictModeFilteredOps.contains(defOp))
5167932d21fSUday Bondhugula           addToWorklist(defOp);
5177932d21fSUday Bondhugula       }
5187932d21fSUday Bondhugula     }
5197932d21fSUday Bondhugula   }
5207932d21fSUday Bondhugula 
5217932d21fSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
5227932d21fSUday Bondhugula     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
5237932d21fSUday Bondhugula     if (strictMode)
5247932d21fSUday Bondhugula       strictModeFilteredOps.erase(op);
5257932d21fSUday Bondhugula   }
5267932d21fSUday Bondhugula 
5277932d21fSUday Bondhugula   /// If `strictMode` is true, any pre-existing ops outside of
5287932d21fSUday Bondhugula   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
5297932d21fSUday Bondhugula   /// If `strictMode` is false, operations that use results of (or supply
5307932d21fSUday Bondhugula   /// operands to) any rewritten ops stemming from the simplification of the
5317932d21fSUday Bondhugula   /// provided ops are in turn simplified; any other ops still remain untouched
5327932d21fSUday Bondhugula   /// (i.e., regardless of `strictMode`).
5337932d21fSUday Bondhugula   bool strictMode = false;
5347932d21fSUday Bondhugula 
5357932d21fSUday Bondhugula   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
5367932d21fSUday Bondhugula   /// These include the supplied set of ops as well as new ops created while
5377932d21fSUday Bondhugula   /// rewriting those ops. This set is not maintained when strictMode is off.
5387932d21fSUday Bondhugula   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
5397932d21fSUday Bondhugula };
5407932d21fSUday Bondhugula 
541be0a7e9fSMehdi Amini } // namespace
5427932d21fSUday Bondhugula 
5437932d21fSUday Bondhugula /// Performs the specified rewrites on `ops` while also trying to fold these ops
5447932d21fSUday Bondhugula /// as well as any other ops that were in turn created due to these rewrite
5457932d21fSUday Bondhugula /// patterns. Any pre-existing ops outside of `ops` remain completely
5467932d21fSUday Bondhugula /// unmodified if `strictMode` is true. If `strictMode` is false, other
5477932d21fSUday Bondhugula /// operations that use results of rewritten ops or supply operands to such ops
5487932d21fSUday Bondhugula /// are in turn simplified; any other ops still remain unmodified (i.e.,
5497932d21fSUday Bondhugula /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
5507932d21fSUday Bondhugula /// result of folding, becoming dead, or via pattern rewrites. Returns true if
5517932d21fSUday Bondhugula /// at all any changes happened.
5527932d21fSUday Bondhugula // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
5537932d21fSUday Bondhugula // or GreedyPatternRewriteDriver::simplify, this method just iterates until
5547932d21fSUday Bondhugula // the worklist is empty. As our objective is to keep simplification "local",
5557932d21fSUday Bondhugula // there is no strong rationale to re-add all operations into the worklist and
5567932d21fSUday Bondhugula // rerun until an iteration changes nothing. If more widereaching simplification
5577932d21fSUday Bondhugula // is desired, GreedyPatternRewriteDriver should be used.
5587932d21fSUday Bondhugula bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
5597932d21fSUday Bondhugula   if (strictMode) {
5607932d21fSUday Bondhugula     strictModeFilteredOps.clear();
5617932d21fSUday Bondhugula     strictModeFilteredOps.insert(ops.begin(), ops.end());
5627932d21fSUday Bondhugula   }
5637932d21fSUday Bondhugula 
5647932d21fSUday Bondhugula   bool changed = false;
5657932d21fSUday Bondhugula   worklist.clear();
5667932d21fSUday Bondhugula   worklistMap.clear();
5677932d21fSUday Bondhugula   for (Operation *op : ops)
5687932d21fSUday Bondhugula     addToWorklist(op);
5697932d21fSUday Bondhugula 
5707932d21fSUday Bondhugula   // These are scratch vectors used in the folding loop below.
5717932d21fSUday Bondhugula   SmallVector<Value, 8> originalOperands, resultValues;
5727932d21fSUday Bondhugula   while (!worklist.empty()) {
5737932d21fSUday Bondhugula     Operation *op = popFromWorklist();
5747932d21fSUday Bondhugula 
5757932d21fSUday Bondhugula     // Nulls get added to the worklist when operations are removed, ignore
5767932d21fSUday Bondhugula     // them.
5777932d21fSUday Bondhugula     if (op == nullptr)
5787932d21fSUday Bondhugula       continue;
5797932d21fSUday Bondhugula 
5807932d21fSUday Bondhugula     // If the operation is trivially dead - remove it.
5817932d21fSUday Bondhugula     if (isOpTriviallyDead(op)) {
5827932d21fSUday Bondhugula       notifyOperationRemoved(op);
5837932d21fSUday Bondhugula       op->erase();
5847932d21fSUday Bondhugula       changed = true;
5857932d21fSUday Bondhugula       continue;
5867932d21fSUday Bondhugula     }
5877932d21fSUday Bondhugula 
5887932d21fSUday Bondhugula     // Collects all the operands and result uses of the given `op` into work
5897932d21fSUday Bondhugula     // list. Also remove `op` and nested ops from worklist.
5907932d21fSUday Bondhugula     originalOperands.assign(op->operand_begin(), op->operand_end());
5917932d21fSUday Bondhugula     auto preReplaceAction = [&](Operation *op) {
5927932d21fSUday Bondhugula       // Add the operands to the worklist for visitation.
5937932d21fSUday Bondhugula       addOperandsToWorklist(originalOperands);
5947932d21fSUday Bondhugula 
5957932d21fSUday Bondhugula       // Add all the users of the result to the worklist so we make sure
5967932d21fSUday Bondhugula       // to revisit them.
5977932d21fSUday Bondhugula       for (Value result : op->getResults())
5987932d21fSUday Bondhugula         for (Operation *userOp : result.getUsers()) {
5997932d21fSUday Bondhugula           if (!strictMode || strictModeFilteredOps.contains(userOp))
6007932d21fSUday Bondhugula             addToWorklist(userOp);
6017932d21fSUday Bondhugula         }
6027932d21fSUday Bondhugula       notifyOperationRemoved(op);
6037932d21fSUday Bondhugula     };
6047932d21fSUday Bondhugula 
6057932d21fSUday Bondhugula     // Add the given operation generated by the folder to the worklist.
6067932d21fSUday Bondhugula     auto processGeneratedConstants = [this](Operation *op) {
6077932d21fSUday Bondhugula       // Newly created ops are also simplified -- these are also "local".
6087932d21fSUday Bondhugula       addToWorklist(op);
6097932d21fSUday Bondhugula       // When strict mode is off, we don't need to maintain
6107932d21fSUday Bondhugula       // strictModeFilteredOps.
6117932d21fSUday Bondhugula       if (strictMode)
6127932d21fSUday Bondhugula         strictModeFilteredOps.insert(op);
6137932d21fSUday Bondhugula     };
6147932d21fSUday Bondhugula 
6157932d21fSUday Bondhugula     // Try to fold this op.
6167932d21fSUday Bondhugula     bool inPlaceUpdate;
6177932d21fSUday Bondhugula     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
6187932d21fSUday Bondhugula                                    preReplaceAction, &inPlaceUpdate))) {
6197932d21fSUday Bondhugula       changed = true;
6207932d21fSUday Bondhugula       if (!inPlaceUpdate) {
6217932d21fSUday Bondhugula         // Op has been erased.
6227932d21fSUday Bondhugula         continue;
6237932d21fSUday Bondhugula       }
6247932d21fSUday Bondhugula     }
6257932d21fSUday Bondhugula 
6267932d21fSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
6277932d21fSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do
6287932d21fSUday Bondhugula     // here.
6297932d21fSUday Bondhugula     changed |= succeeded(matcher.matchAndRewrite(op, *this));
6307932d21fSUday Bondhugula   }
6317932d21fSUday Bondhugula 
6327932d21fSUday Bondhugula   return changed;
6337932d21fSUday Bondhugula }
6347932d21fSUday Bondhugula 
63504b5274eSUday Bondhugula /// Rewrites only `op` using the supplied canonicalization patterns and
63604b5274eSUday Bondhugula /// folding. `erased` is set to true if the op is erased as a result of being
63704b5274eSUday Bondhugula /// folded, replaced, or dead.
6383e98fbf4SRiver Riddle LogicalResult mlir::applyOpPatternsAndFold(
63979d7f618SChris Lattner     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
64004b5274eSUday Bondhugula   // Start the pattern driver.
64164716b2cSChris Lattner   GreedyRewriteConfig config;
64204b5274eSUday Bondhugula   OpPatternRewriteDriver driver(op->getContext(), patterns);
64304b5274eSUday Bondhugula   bool opErased;
6443e98fbf4SRiver Riddle   LogicalResult converged =
64564716b2cSChris Lattner       driver.simplifyLocally(op, config.maxIterations, opErased);
64604b5274eSUday Bondhugula   if (erased)
64704b5274eSUday Bondhugula     *erased = opErased;
6483e98fbf4SRiver Riddle   LLVM_DEBUG(if (failed(converged)) {
64904b5274eSUday Bondhugula     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
65064716b2cSChris Lattner                  << config.maxIterations << " times";
65104b5274eSUday Bondhugula   });
65204b5274eSUday Bondhugula   return converged;
65304b5274eSUday Bondhugula }
6547932d21fSUday Bondhugula 
6557932d21fSUday Bondhugula bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
6567932d21fSUday Bondhugula                                   const FrozenRewritePatternSet &patterns,
6577932d21fSUday Bondhugula                                   bool strict) {
6587932d21fSUday Bondhugula   if (ops.empty())
6597932d21fSUday Bondhugula     return false;
6607932d21fSUday Bondhugula 
6617932d21fSUday Bondhugula   // Start the pattern driver.
6627932d21fSUday Bondhugula   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
6637932d21fSUday Bondhugula                                      strict);
6647932d21fSUday Bondhugula   return driver.simplifyLocally(ops);
6657932d21fSUday Bondhugula }
666