//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements mlir::applyPatternsAndFoldGreedily.
//
//===----------------------------------------------------------------------===//
1264d52014SChris Lattner 
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14af371f9fSRiver Riddle #include "mlir/IR/Matchers.h"
15eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
16b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
171982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
18fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
1964d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
205c757087SFeng Liu #include "llvm/Support/CommandLine.h"
215c757087SFeng Liu #include "llvm/Support/Debug.h"
225652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h"
235c757087SFeng Liu #include "llvm/Support/raw_ostream.h"
244e40c832SLei Zhang 
2564d52014SChris Lattner using namespace mlir;
2664d52014SChris Lattner 
275652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter"
285c757087SFeng Liu 
//===----------------------------------------------------------------------===//
// GreedyPatternRewriteDriver
//===----------------------------------------------------------------------===//
3204b5274eSUday Bondhugula 
3364d52014SChris Lattner namespace {
3464d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3564d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
364bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3764d52014SChris Lattner public:
382566a72aSRiver Riddle   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
39648f34a2SChris Lattner                                       const FrozenRewritePatternSet &patterns,
40b7144ab7SRiver Riddle                                       const GreedyRewriteConfig &config);
413e98fbf4SRiver Riddle 
42b7144ab7SRiver Riddle   /// Simplify the operations within the given regions.
4364716b2cSChris Lattner   bool simplify(MutableArrayRef<Region> regions);
4464d52014SChris Lattner 
45b7144ab7SRiver Riddle   /// Add the given operation to the worklist.
46*ba3a9f51SChia-hung Duan   virtual void addToWorklist(Operation *op);
475c4f1fddSRiver Riddle 
48b7144ab7SRiver Riddle   /// Pop the next operation from the worklist.
49b7144ab7SRiver Riddle   Operation *popFromWorklist();
5064d52014SChris Lattner 
51b7144ab7SRiver Riddle   /// If the specified operation is in the worklist, remove it.
52b7144ab7SRiver Riddle   void removeFromWorklist(Operation *op);
5364d52014SChris Lattner 
544bd9f936SChris Lattner protected:
55851a8516SRiver Riddle   // Implement the hook for inserting operations, and make sure that newly
56851a8516SRiver Riddle   // inserted ops are added to the worklist for processing.
57b7144ab7SRiver Riddle   void notifyOperationInserted(Operation *op) override;
5864d52014SChris Lattner 
597932d21fSUday Bondhugula   // Look over the provided operands for any defining operations that should
607932d21fSUday Bondhugula   // be re-added to the worklist. This function should be called when an
617932d21fSUday Bondhugula   // operation is modified or removed, as it may trigger further
627932d21fSUday Bondhugula   // simplifications.
63*ba3a9f51SChia-hung Duan   void addOperandsToWorklist(ValueRange operands);
647932d21fSUday Bondhugula 
6564d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
6664d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
67b7144ab7SRiver Riddle   void notifyOperationRemoved(Operation *op) override;
6864d52014SChris Lattner 
69085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
70085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
71085b687fSChris Lattner   // before the root is changed.
72b7144ab7SRiver Riddle   void notifyRootReplaced(Operation *op) override;
73085b687fSChris Lattner 
745652ecc3SRiver Riddle   /// PatternRewriter hook for erasing a dead operation.
755652ecc3SRiver Riddle   void eraseOp(Operation *op) override;
765652ecc3SRiver Riddle 
775652ecc3SRiver Riddle   /// PatternRewriter hook for notifying match failure reasons.
785652ecc3SRiver Riddle   LogicalResult
79ea64828aSRiver Riddle   notifyMatchFailure(Location loc,
805652ecc3SRiver Riddle                      function_ref<void(Diagnostic &)> reasonCallback) override;
815652ecc3SRiver Riddle 
823e98fbf4SRiver Riddle   /// The low-level pattern applicator.
833e98fbf4SRiver Riddle   PatternApplicator matcher;
844bd9f936SChris Lattner 
854bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
864bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
87e7a2ef21SRiver Riddle   /// efficiently remove operations from the worklist when they are erased, even
88e7a2ef21SRiver Riddle   /// if they aren't the root of a pattern.
8999b87c97SRiver Riddle   std::vector<Operation *> worklist;
9099b87c97SRiver Riddle   DenseMap<Operation *, unsigned> worklistMap;
9160a29837SRiver Riddle 
9260a29837SRiver Riddle   /// Non-pattern based folder for operations.
9360a29837SRiver Riddle   OperationFolder folder;
94648f34a2SChris Lattner 
957932d21fSUday Bondhugula private:
9664716b2cSChris Lattner   /// Configuration information for how to simplify.
9764716b2cSChris Lattner   GreedyRewriteConfig config;
985652ecc3SRiver Riddle 
995652ecc3SRiver Riddle #ifndef NDEBUG
1005652ecc3SRiver Riddle   /// A logger used to emit information during the application process.
1015652ecc3SRiver Riddle   llvm::ScopedPrinter logger{llvm::dbgs()};
1025652ecc3SRiver Riddle #endif
10364d52014SChris Lattner };
104be0a7e9fSMehdi Amini } // namespace
10564d52014SChris Lattner 
GreedyPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,const GreedyRewriteConfig & config)106b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
107b7144ab7SRiver Riddle     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
108b7144ab7SRiver Riddle     const GreedyRewriteConfig &config)
109b7144ab7SRiver Riddle     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
110b7144ab7SRiver Riddle   worklist.reserve(64);
111b7144ab7SRiver Riddle 
112b7144ab7SRiver Riddle   // Apply a simple cost model based solely on pattern benefit.
113b7144ab7SRiver Riddle   matcher.applyDefaultCostModel();
114b7144ab7SRiver Riddle }
115b7144ab7SRiver Riddle 
simplify(MutableArrayRef<Region> regions)11664716b2cSChris Lattner bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
1175652ecc3SRiver Riddle #ifndef NDEBUG
1185652ecc3SRiver Riddle   const char *logLineComment =
1195652ecc3SRiver Riddle       "//===-------------------------------------------===//\n";
1205652ecc3SRiver Riddle 
1215652ecc3SRiver Riddle   /// A utility function to log a process result for the given reason.
1225652ecc3SRiver Riddle   auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
1235652ecc3SRiver Riddle     logger.unindent();
1245652ecc3SRiver Riddle     logger.startLine() << "} -> " << result;
1255652ecc3SRiver Riddle     if (!msg.isTriviallyEmpty())
1265652ecc3SRiver Riddle       logger.getOStream() << " : " << msg;
1275652ecc3SRiver Riddle     logger.getOStream() << "\n";
1285652ecc3SRiver Riddle   };
1295652ecc3SRiver Riddle   auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
1305652ecc3SRiver Riddle     logResult(result, msg);
1315652ecc3SRiver Riddle     logger.startLine() << logLineComment;
1325652ecc3SRiver Riddle   };
1335652ecc3SRiver Riddle #endif
1345652ecc3SRiver Riddle 
1357814b559Srkayaith   auto insertKnownConstant = [&](Operation *op) {
1367814b559Srkayaith     // Check for existing constants when populating the worklist. This avoids
1377814b559Srkayaith     // accidentally reversing the constant order during processing.
1387814b559Srkayaith     Attribute constValue;
1397814b559Srkayaith     if (matchPattern(op, m_Constant(&constValue)))
1407814b559Srkayaith       if (!folder.insertKnownConstant(op, constValue))
1417814b559Srkayaith         return true;
1427814b559Srkayaith     return false;
1437814b559Srkayaith   };
1447814b559Srkayaith 
1455c757087SFeng Liu   bool changed = false;
14664716b2cSChris Lattner   unsigned iteration = 0;
1475c757087SFeng Liu   do {
148648f34a2SChris Lattner     worklist.clear();
149648f34a2SChris Lattner     worklistMap.clear();
150648f34a2SChris Lattner 
15164716b2cSChris Lattner     if (!config.useTopDownTraversal) {
15264716b2cSChris Lattner       // Add operations to the worklist in postorder.
153af371f9fSRiver Riddle       for (auto &region : regions) {
1547814b559Srkayaith         region.walk([&](Operation *op) {
1557814b559Srkayaith           if (!insertKnownConstant(op))
156af371f9fSRiver Riddle             addToWorklist(op);
157af371f9fSRiver Riddle         });
158af371f9fSRiver Riddle       }
15964716b2cSChris Lattner     } else {
160648f34a2SChris Lattner       // Add all nested operations to the worklist in preorder.
161aa568e08SRiver Riddle       for (auto &region : regions) {
1627814b559Srkayaith         region.walk<WalkOrder::PreOrder>([&](Operation *op) {
163aa568e08SRiver Riddle           if (!insertKnownConstant(op)) {
1647814b559Srkayaith             worklist.push_back(op);
165aa568e08SRiver Riddle             return WalkResult::advance();
166aa568e08SRiver Riddle           }
167aa568e08SRiver Riddle           return WalkResult::skip();
1687814b559Srkayaith         });
169aa568e08SRiver Riddle       }
170648f34a2SChris Lattner 
171648f34a2SChris Lattner       // Reverse the list so our pop-back loop processes them in-order.
172648f34a2SChris Lattner       std::reverse(worklist.begin(), worklist.end());
173648f34a2SChris Lattner       // Remember the reverse index.
174648f34a2SChris Lattner       for (size_t i = 0, e = worklist.size(); i != e; ++i)
175648f34a2SChris Lattner         worklistMap[worklist[i]] = i;
176648f34a2SChris Lattner     }
1774e40c832SLei Zhang 
1784e40c832SLei Zhang     // These are scratch vectors used in the folding loop below.
179e62a6956SRiver Riddle     SmallVector<Value, 8> originalOperands, resultValues;
18064d52014SChris Lattner 
1815c757087SFeng Liu     changed = false;
18264d52014SChris Lattner     while (!worklist.empty()) {
18364d52014SChris Lattner       auto *op = popFromWorklist();
18464d52014SChris Lattner 
1855c757087SFeng Liu       // Nulls get added to the worklist when operations are removed, ignore
1865c757087SFeng Liu       // them.
18764d52014SChris Lattner       if (op == nullptr)
18864d52014SChris Lattner         continue;
18964d52014SChris Lattner 
1905652ecc3SRiver Riddle       LLVM_DEBUG({
1915652ecc3SRiver Riddle         logger.getOStream() << "\n";
1925652ecc3SRiver Riddle         logger.startLine() << logLineComment;
1935652ecc3SRiver Riddle         logger.startLine() << "Processing operation : '" << op->getName()
1945652ecc3SRiver Riddle                            << "'(" << op << ") {\n";
1955652ecc3SRiver Riddle         logger.indent();
1965652ecc3SRiver Riddle 
1975652ecc3SRiver Riddle         // If the operation has no regions, just print it here.
1985652ecc3SRiver Riddle         if (op->getNumRegions() == 0) {
1995652ecc3SRiver Riddle           op->print(
2005652ecc3SRiver Riddle               logger.startLine(),
2015652ecc3SRiver Riddle               OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
2025652ecc3SRiver Riddle           logger.getOStream() << "\n\n";
2035652ecc3SRiver Riddle         }
2045652ecc3SRiver Riddle       });
2055652ecc3SRiver Riddle 
2060ddba0bdSRiver Riddle       // If the operation is trivially dead - remove it.
2070ddba0bdSRiver Riddle       if (isOpTriviallyDead(op)) {
2086a501e3dSAndy Ly         notifyOperationRemoved(op);
20964d52014SChris Lattner         op->erase();
210f875e55bSUday Bondhugula         changed = true;
2115652ecc3SRiver Riddle 
2125652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
21364d52014SChris Lattner         continue;
21464d52014SChris Lattner       }
21564d52014SChris Lattner 
2164e40c832SLei Zhang       // Collects all the operands and result uses of the given `op` into work
2176a501e3dSAndy Ly       // list. Also remove `op` and nested ops from worklist.
2181982afb1SRiver Riddle       originalOperands.assign(op->operand_begin(), op->operand_end());
2196a501e3dSAndy Ly       auto preReplaceAction = [&](Operation *op) {
220a8866258SRiver Riddle         // Add the operands to the worklist for visitation.
221*ba3a9f51SChia-hung Duan         addOperandsToWorklist(originalOperands);
2221982afb1SRiver Riddle 
2234e40c832SLei Zhang         // Add all the users of the result to the worklist so we make sure
2244e40c832SLei Zhang         // to revisit them.
22535807bc4SRiver Riddle         for (auto result : op->getResults())
226cc673894SUday Bondhugula           for (auto *userOp : result.getUsers())
227cc673894SUday Bondhugula             addToWorklist(userOp);
2286a501e3dSAndy Ly 
2296a501e3dSAndy Ly         notifyOperationRemoved(op);
2304e40c832SLei Zhang       };
23164d52014SChris Lattner 
232648f34a2SChris Lattner       // Add the given operation to the worklist.
233648f34a2SChris Lattner       auto collectOps = [this](Operation *op) { addToWorklist(op); };
234648f34a2SChris Lattner 
2351982afb1SRiver Riddle       // Try to fold this op.
236cbcb12fdSUday Bondhugula       bool inPlaceUpdate;
237cbcb12fdSUday Bondhugula       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
238cbcb12fdSUday Bondhugula                                       &inPlaceUpdate)))) {
2395652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
2405652ecc3SRiver Riddle 
241f875e55bSUday Bondhugula         changed = true;
242cbcb12fdSUday Bondhugula         if (!inPlaceUpdate)
243934b6d12SChris Lattner           continue;
24464d52014SChris Lattner       }
24564d52014SChris Lattner 
24632052c84SRiver Riddle       // Try to match one of the patterns. The rewriter is automatically
247648f34a2SChris Lattner       // notified of any necessary changes, so there is nothing else to do
248648f34a2SChris Lattner       // here.
2495652ecc3SRiver Riddle #ifndef NDEBUG
2505652ecc3SRiver Riddle       auto canApply = [&](const Pattern &pattern) {
2515652ecc3SRiver Riddle         LLVM_DEBUG({
2525652ecc3SRiver Riddle           logger.getOStream() << "\n";
2535652ecc3SRiver Riddle           logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
2545652ecc3SRiver Riddle                              << op->getName() << " -> (";
2555652ecc3SRiver Riddle           llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
2565652ecc3SRiver Riddle           logger.getOStream() << ")' {\n";
2575652ecc3SRiver Riddle           logger.indent();
2585652ecc3SRiver Riddle         });
2595652ecc3SRiver Riddle         return true;
2605652ecc3SRiver Riddle       };
2615652ecc3SRiver Riddle       auto onFailure = [&](const Pattern &pattern) {
2625652ecc3SRiver Riddle         LLVM_DEBUG(logResult("failure", "pattern failed to match"));
2635652ecc3SRiver Riddle       };
2645652ecc3SRiver Riddle       auto onSuccess = [&](const Pattern &pattern) {
2655652ecc3SRiver Riddle         LLVM_DEBUG(logResult("success", "pattern applied successfully"));
2665652ecc3SRiver Riddle         return success();
2675652ecc3SRiver Riddle       };
2685652ecc3SRiver Riddle 
2695652ecc3SRiver Riddle       LogicalResult matchResult =
2705652ecc3SRiver Riddle           matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
2715652ecc3SRiver Riddle       if (succeeded(matchResult))
2725652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
2735652ecc3SRiver Riddle       else
2745652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
2755652ecc3SRiver Riddle #else
2765652ecc3SRiver Riddle       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
2775652ecc3SRiver Riddle #endif
2785652ecc3SRiver Riddle       changed |= succeeded(matchResult);
27964d52014SChris Lattner     }
280a32f0dcbSRiver Riddle 
281648f34a2SChris Lattner     // After applying patterns, make sure that the CFG of each of the regions
282648f34a2SChris Lattner     // is kept up to date.
28364716b2cSChris Lattner     if (config.enableRegionSimplification)
284d75a611aSRiver Riddle       changed |= succeeded(simplifyRegions(*this, regions));
285519663beSFrederik Gossen   } while (changed &&
286673e9828SFrederik Gossen            (iteration++ < config.maxIterations ||
287519663beSFrederik Gossen             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
288648f34a2SChris Lattner 
2895c757087SFeng Liu   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
2905c757087SFeng Liu   return !changed;
29164d52014SChris Lattner }
29264d52014SChris Lattner 
addToWorklist(Operation * op)293b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
294b7144ab7SRiver Riddle   // Check to see if the worklist already contains this op.
295b7144ab7SRiver Riddle   if (worklistMap.count(op))
296b7144ab7SRiver Riddle     return;
297b7144ab7SRiver Riddle 
298b7144ab7SRiver Riddle   worklistMap[op] = worklist.size();
299b7144ab7SRiver Riddle   worklist.push_back(op);
300b7144ab7SRiver Riddle }
301b7144ab7SRiver Riddle 
popFromWorklist()302b7144ab7SRiver Riddle Operation *GreedyPatternRewriteDriver::popFromWorklist() {
303b7144ab7SRiver Riddle   auto *op = worklist.back();
304b7144ab7SRiver Riddle   worklist.pop_back();
305b7144ab7SRiver Riddle 
306b7144ab7SRiver Riddle   // This operation is no longer in the worklist, keep worklistMap up to date.
307b7144ab7SRiver Riddle   if (op)
308b7144ab7SRiver Riddle     worklistMap.erase(op);
309b7144ab7SRiver Riddle   return op;
310b7144ab7SRiver Riddle }
311b7144ab7SRiver Riddle 
removeFromWorklist(Operation * op)312b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
313b7144ab7SRiver Riddle   auto it = worklistMap.find(op);
314b7144ab7SRiver Riddle   if (it != worklistMap.end()) {
315b7144ab7SRiver Riddle     assert(worklist[it->second] == op && "malformed worklist data structure");
316b7144ab7SRiver Riddle     worklist[it->second] = nullptr;
317b7144ab7SRiver Riddle     worklistMap.erase(it);
318b7144ab7SRiver Riddle   }
319b7144ab7SRiver Riddle }
320b7144ab7SRiver Riddle 
notifyOperationInserted(Operation * op)321b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
3225652ecc3SRiver Riddle   LLVM_DEBUG({
3235652ecc3SRiver Riddle     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
3245652ecc3SRiver Riddle                        << ")\n";
3255652ecc3SRiver Riddle   });
326b7144ab7SRiver Riddle   addToWorklist(op);
327b7144ab7SRiver Riddle }
328b7144ab7SRiver Riddle 
addOperandsToWorklist(ValueRange operands)329*ba3a9f51SChia-hung Duan void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
330b7144ab7SRiver Riddle   for (Value operand : operands) {
331b7144ab7SRiver Riddle     // If the use count of this operand is now < 2, we re-add the defining
332b7144ab7SRiver Riddle     // operation to the worklist.
333b7144ab7SRiver Riddle     // TODO: This is based on the fact that zero use operations
334b7144ab7SRiver Riddle     // may be deleted, and that single use values often have more
335b7144ab7SRiver Riddle     // canonicalization opportunities.
336b7144ab7SRiver Riddle     if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
337b7144ab7SRiver Riddle       continue;
338b7144ab7SRiver Riddle     if (auto *defOp = operand.getDefiningOp())
339b7144ab7SRiver Riddle       addToWorklist(defOp);
340b7144ab7SRiver Riddle   }
341b7144ab7SRiver Riddle }
342b7144ab7SRiver Riddle 
notifyOperationRemoved(Operation * op)343b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
344*ba3a9f51SChia-hung Duan   addOperandsToWorklist(op->getOperands());
345b7144ab7SRiver Riddle   op->walk([this](Operation *operation) {
346b7144ab7SRiver Riddle     removeFromWorklist(operation);
347b7144ab7SRiver Riddle     folder.notifyRemoval(operation);
348b7144ab7SRiver Riddle   });
349b7144ab7SRiver Riddle }
350b7144ab7SRiver Riddle 
notifyRootReplaced(Operation * op)351b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
3525652ecc3SRiver Riddle   LLVM_DEBUG({
3535652ecc3SRiver Riddle     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
3545652ecc3SRiver Riddle                        << ")\n";
3555652ecc3SRiver Riddle   });
356b7144ab7SRiver Riddle   for (auto result : op->getResults())
357b7144ab7SRiver Riddle     for (auto *user : result.getUsers())
358b7144ab7SRiver Riddle       addToWorklist(user);
359b7144ab7SRiver Riddle }
360b7144ab7SRiver Riddle 
eraseOp(Operation * op)3615652ecc3SRiver Riddle void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
3625652ecc3SRiver Riddle   LLVM_DEBUG({
3635652ecc3SRiver Riddle     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
3645652ecc3SRiver Riddle                        << ")\n";
3655652ecc3SRiver Riddle   });
3665652ecc3SRiver Riddle   PatternRewriter::eraseOp(op);
3675652ecc3SRiver Riddle }
3685652ecc3SRiver Riddle 
notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)3695652ecc3SRiver Riddle LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
370ea64828aSRiver Riddle     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
3715652ecc3SRiver Riddle   LLVM_DEBUG({
372ea64828aSRiver Riddle     Diagnostic diag(loc, DiagnosticSeverity::Remark);
3735652ecc3SRiver Riddle     reasonCallback(diag);
3745652ecc3SRiver Riddle     logger.startLine() << "** Failure : " << diag.str() << "\n";
3755652ecc3SRiver Riddle   });
3765652ecc3SRiver Riddle   return failure();
3775652ecc3SRiver Riddle }
3785652ecc3SRiver Riddle 
379e7a2ef21SRiver Riddle /// Rewrite the regions of the specified operation, which must be isolated from
380e7a2ef21SRiver Riddle /// above, by repeatedly applying the highest benefit patterns in a greedy
3813e98fbf4SRiver Riddle /// work-list driven manner. Return success if no more patterns can be matched
3823e98fbf4SRiver Riddle /// in the result operation regions. Note: This does not apply patterns to the
3833e98fbf4SRiver Riddle /// top-level operation itself.
38464d52014SChris Lattner ///
3853e98fbf4SRiver Riddle LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,const FrozenRewritePatternSet & patterns,GreedyRewriteConfig config)3863e98fbf4SRiver Riddle mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
3870b20413eSUday Bondhugula                                    const FrozenRewritePatternSet &patterns,
38864716b2cSChris Lattner                                    GreedyRewriteConfig config) {
3896b1cc3c6SRiver Riddle   if (regions.empty())
3903e98fbf4SRiver Riddle     return success();
3916b1cc3c6SRiver Riddle 
392e7a2ef21SRiver Riddle   // The top-level operation must be known to be isolated from above to
393e7a2ef21SRiver Riddle   // prevent performing canonicalizations on operations defined at or above
394e7a2ef21SRiver Riddle   // the region containing 'op'.
3956b1cc3c6SRiver Riddle   auto regionIsIsolated = [](Region &region) {
396fe7c0d90SRiver Riddle     return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
3976b1cc3c6SRiver Riddle   };
3986b1cc3c6SRiver Riddle   (void)regionIsIsolated;
3996b1cc3c6SRiver Riddle   assert(llvm::all_of(regions, regionIsIsolated) &&
4006b1cc3c6SRiver Riddle          "patterns can only be applied to operations IsolatedFromAbove");
401e7a2ef21SRiver Riddle 
4026b1cc3c6SRiver Riddle   // Start the pattern driver.
40364716b2cSChris Lattner   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
40464716b2cSChris Lattner   bool converged = driver.simplify(regions);
4055c757087SFeng Liu   LLVM_DEBUG(if (!converged) {
406e7a2ef21SRiver Riddle     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
40764716b2cSChris Lattner                  << config.maxIterations << " times\n";
4085c757087SFeng Liu   });
4093e98fbf4SRiver Riddle   return success(converged);
41064d52014SChris Lattner }
41104b5274eSUday Bondhugula 
//===----------------------------------------------------------------------===//
// OpPatternRewriteDriver
//===----------------------------------------------------------------------===//
41504b5274eSUday Bondhugula 
41604b5274eSUday Bondhugula namespace {
41704b5274eSUday Bondhugula /// This is a simple driver for the PatternMatcher to apply patterns and perform
41804b5274eSUday Bondhugula /// folding on a single op. It repeatedly applies locally optimal patterns.
41904b5274eSUday Bondhugula class OpPatternRewriteDriver : public PatternRewriter {
42004b5274eSUday Bondhugula public:
OpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns)42104b5274eSUday Bondhugula   explicit OpPatternRewriteDriver(MLIRContext *ctx,
42279d7f618SChris Lattner                                   const FrozenRewritePatternSet &patterns)
4233e98fbf4SRiver Riddle       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
4243e98fbf4SRiver Riddle     // Apply a simple cost model based solely on pattern benefit.
4253e98fbf4SRiver Riddle     matcher.applyDefaultCostModel();
4263e98fbf4SRiver Riddle   }
42704b5274eSUday Bondhugula 
4283e98fbf4SRiver Riddle   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
42904b5274eSUday Bondhugula 
43004b5274eSUday Bondhugula   // These are hooks implemented for PatternRewriter.
43104b5274eSUday Bondhugula protected:
43204b5274eSUday Bondhugula   /// If an operation is about to be removed, mark it so that we can let clients
43304b5274eSUday Bondhugula   /// know.
notifyOperationRemoved(Operation * op)43404b5274eSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
43504b5274eSUday Bondhugula     opErasedViaPatternRewrites = true;
43604b5274eSUday Bondhugula   }
43704b5274eSUday Bondhugula 
43804b5274eSUday Bondhugula   // When a root is going to be replaced, its removal will be notified as well.
43904b5274eSUday Bondhugula   // So there is nothing to do here.
notifyRootReplaced(Operation * op)44004b5274eSUday Bondhugula   void notifyRootReplaced(Operation *op) override {}
44104b5274eSUday Bondhugula 
44204b5274eSUday Bondhugula private:
4433e98fbf4SRiver Riddle   /// The low-level pattern applicator.
4443e98fbf4SRiver Riddle   PatternApplicator matcher;
44504b5274eSUday Bondhugula 
44604b5274eSUday Bondhugula   /// Non-pattern based folder for operations.
44704b5274eSUday Bondhugula   OperationFolder folder;
44804b5274eSUday Bondhugula 
44904b5274eSUday Bondhugula   /// Set to true if the operation has been erased via pattern rewrites.
45004b5274eSUday Bondhugula   bool opErasedViaPatternRewrites = false;
45104b5274eSUday Bondhugula };
45204b5274eSUday Bondhugula 
453be0a7e9fSMehdi Amini } // namespace
45404b5274eSUday Bondhugula 
4557932d21fSUday Bondhugula /// Performs the rewrites and folding only on `op`. The simplification
4567932d21fSUday Bondhugula /// converges if the op is erased as a result of being folded, replaced, or
4577932d21fSUday Bondhugula /// becoming dead, or no more changes happen in an iteration. Returns success if
4587932d21fSUday Bondhugula /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
4597932d21fSUday Bondhugula /// gets erased.
simplifyLocally(Operation * op,int maxIterations,bool & erased)4603e98fbf4SRiver Riddle LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
4613e98fbf4SRiver Riddle                                                       int maxIterations,
46204b5274eSUday Bondhugula                                                       bool &erased) {
46304b5274eSUday Bondhugula   bool changed = false;
46404b5274eSUday Bondhugula   erased = false;
46504b5274eSUday Bondhugula   opErasedViaPatternRewrites = false;
4667932d21fSUday Bondhugula   int iterations = 0;
46704b5274eSUday Bondhugula   // Iterate until convergence or until maxIterations. Deletion of the op as
46804b5274eSUday Bondhugula   // a result of being dead or folded is convergence.
46904b5274eSUday Bondhugula   do {
470ff87c4d3SChristian Sigg     changed = false;
471ff87c4d3SChristian Sigg 
47204b5274eSUday Bondhugula     // If the operation is trivially dead - remove it.
47304b5274eSUday Bondhugula     if (isOpTriviallyDead(op)) {
47404b5274eSUday Bondhugula       op->erase();
47504b5274eSUday Bondhugula       erased = true;
4763e98fbf4SRiver Riddle       return success();
47704b5274eSUday Bondhugula     }
47804b5274eSUday Bondhugula 
47904b5274eSUday Bondhugula     // Try to fold this op.
48004b5274eSUday Bondhugula     bool inPlaceUpdate;
48104b5274eSUday Bondhugula     if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
48204b5274eSUday Bondhugula                                    /*preReplaceAction=*/nullptr,
48304b5274eSUday Bondhugula                                    &inPlaceUpdate))) {
48404b5274eSUday Bondhugula       changed = true;
48504b5274eSUday Bondhugula       if (!inPlaceUpdate) {
48604b5274eSUday Bondhugula         erased = true;
4873e98fbf4SRiver Riddle         return success();
48804b5274eSUday Bondhugula       }
48904b5274eSUday Bondhugula     }
49004b5274eSUday Bondhugula 
49104b5274eSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
49204b5274eSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do here.
4933e98fbf4SRiver Riddle     changed |= succeeded(matcher.matchAndRewrite(op, *this));
49404b5274eSUday Bondhugula     if ((erased = opErasedViaPatternRewrites))
4953e98fbf4SRiver Riddle       return success();
496519663beSFrederik Gossen   } while (changed &&
497519663beSFrederik Gossen            (++iterations < maxIterations ||
498519663beSFrederik Gossen             maxIterations == GreedyRewriteConfig::kNoIterationLimit));
49904b5274eSUday Bondhugula 
50004b5274eSUday Bondhugula   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
5013e98fbf4SRiver Riddle   return failure(changed);
50204b5274eSUday Bondhugula }
50304b5274eSUday Bondhugula 
//===----------------------------------------------------------------------===//
// MultiOpPatternRewriteDriver
//===----------------------------------------------------------------------===//
5077932d21fSUday Bondhugula 
5087932d21fSUday Bondhugula namespace {
5097932d21fSUday Bondhugula 
5107932d21fSUday Bondhugula /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
5117932d21fSUday Bondhugula /// perform folding for a supplied set of ops. It repeatedly simplifies while
5127932d21fSUday Bondhugula /// restricting the rewrites to only the provided set of ops or optionally
5137932d21fSUday Bondhugula /// to those directly affected by it (result users or operand providers).
5147932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
5157932d21fSUday Bondhugula public:
MultiOpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,bool strict)5167932d21fSUday Bondhugula   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
5177932d21fSUday Bondhugula                                        const FrozenRewritePatternSet &patterns,
5187932d21fSUday Bondhugula                                        bool strict)
5197932d21fSUday Bondhugula       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
5207932d21fSUday Bondhugula         strictMode(strict) {}
5217932d21fSUday Bondhugula 
5227932d21fSUday Bondhugula   bool simplifyLocally(ArrayRef<Operation *> op);
5237932d21fSUday Bondhugula 
addToWorklist(Operation * op)524*ba3a9f51SChia-hung Duan   void addToWorklist(Operation *op) override {
525*ba3a9f51SChia-hung Duan     if (!strictMode || strictModeFilteredOps.contains(op))
526*ba3a9f51SChia-hung Duan       GreedyPatternRewriteDriver::addToWorklist(op);
5277932d21fSUday Bondhugula   }
5287932d21fSUday Bondhugula 
529*ba3a9f51SChia-hung Duan private:
notifyOperationInserted(Operation * op)5302aeffc6dSChia-hung Duan   void notifyOperationInserted(Operation *op) override {
5312aeffc6dSChia-hung Duan     GreedyPatternRewriteDriver::notifyOperationInserted(op);
5322aeffc6dSChia-hung Duan     if (strictMode)
5332aeffc6dSChia-hung Duan       strictModeFilteredOps.insert(op);
5342aeffc6dSChia-hung Duan   }
5352aeffc6dSChia-hung Duan 
notifyOperationRemoved(Operation * op)5367932d21fSUday Bondhugula   void notifyOperationRemoved(Operation *op) override {
5377932d21fSUday Bondhugula     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
5387932d21fSUday Bondhugula     if (strictMode)
5397932d21fSUday Bondhugula       strictModeFilteredOps.erase(op);
5407932d21fSUday Bondhugula   }
5417932d21fSUday Bondhugula 
5427932d21fSUday Bondhugula   /// If `strictMode` is true, any pre-existing ops outside of
5437932d21fSUday Bondhugula   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
5447932d21fSUday Bondhugula   /// If `strictMode` is false, operations that use results of (or supply
5457932d21fSUday Bondhugula   /// operands to) any rewritten ops stemming from the simplification of the
5467932d21fSUday Bondhugula   /// provided ops are in turn simplified; any other ops still remain untouched
5477932d21fSUday Bondhugula   /// (i.e., regardless of `strictMode`).
5487932d21fSUday Bondhugula   bool strictMode = false;
5497932d21fSUday Bondhugula 
5507932d21fSUday Bondhugula   /// The list of ops we are restricting our rewrites to if `strictMode` is on.
5517932d21fSUday Bondhugula   /// These include the supplied set of ops as well as new ops created while
5527932d21fSUday Bondhugula   /// rewriting those ops. This set is not maintained when strictMode is off.
5537932d21fSUday Bondhugula   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
5547932d21fSUday Bondhugula };
5557932d21fSUday Bondhugula 
556be0a7e9fSMehdi Amini } // namespace
5577932d21fSUday Bondhugula 
5587932d21fSUday Bondhugula /// Performs the specified rewrites on `ops` while also trying to fold these ops
5597932d21fSUday Bondhugula /// as well as any other ops that were in turn created due to these rewrite
5607932d21fSUday Bondhugula /// patterns. Any pre-existing ops outside of `ops` remain completely
5617932d21fSUday Bondhugula /// unmodified if `strictMode` is true. If `strictMode` is false, other
5627932d21fSUday Bondhugula /// operations that use results of rewritten ops or supply operands to such ops
5637932d21fSUday Bondhugula /// are in turn simplified; any other ops still remain unmodified (i.e.,
5647932d21fSUday Bondhugula /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
5657932d21fSUday Bondhugula /// result of folding, becoming dead, or via pattern rewrites. Returns true if
5667932d21fSUday Bondhugula /// at all any changes happened.
5677932d21fSUday Bondhugula // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
5687932d21fSUday Bondhugula // or GreedyPatternRewriteDriver::simplify, this method just iterates until
5697932d21fSUday Bondhugula // the worklist is empty. As our objective is to keep simplification "local",
5707932d21fSUday Bondhugula // there is no strong rationale to re-add all operations into the worklist and
5717932d21fSUday Bondhugula // rerun until an iteration changes nothing. If more widereaching simplification
5727932d21fSUday Bondhugula // is desired, GreedyPatternRewriteDriver should be used.
simplifyLocally(ArrayRef<Operation * > ops)5737932d21fSUday Bondhugula bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
5747932d21fSUday Bondhugula   if (strictMode) {
5757932d21fSUday Bondhugula     strictModeFilteredOps.clear();
5767932d21fSUday Bondhugula     strictModeFilteredOps.insert(ops.begin(), ops.end());
5777932d21fSUday Bondhugula   }
5787932d21fSUday Bondhugula 
5797932d21fSUday Bondhugula   bool changed = false;
5807932d21fSUday Bondhugula   worklist.clear();
5817932d21fSUday Bondhugula   worklistMap.clear();
5827932d21fSUday Bondhugula   for (Operation *op : ops)
5837932d21fSUday Bondhugula     addToWorklist(op);
5847932d21fSUday Bondhugula 
5857932d21fSUday Bondhugula   // These are scratch vectors used in the folding loop below.
5867932d21fSUday Bondhugula   SmallVector<Value, 8> originalOperands, resultValues;
5877932d21fSUday Bondhugula   while (!worklist.empty()) {
5887932d21fSUday Bondhugula     Operation *op = popFromWorklist();
5897932d21fSUday Bondhugula 
5907932d21fSUday Bondhugula     // Nulls get added to the worklist when operations are removed, ignore
5917932d21fSUday Bondhugula     // them.
5927932d21fSUday Bondhugula     if (op == nullptr)
5937932d21fSUday Bondhugula       continue;
5947932d21fSUday Bondhugula 
595633ad1d8SChia-hung Duan     assert((!strictMode || strictModeFilteredOps.contains(op)) &&
596633ad1d8SChia-hung Duan            "unexpected op was inserted under strict mode");
597633ad1d8SChia-hung Duan 
5987932d21fSUday Bondhugula     // If the operation is trivially dead - remove it.
5997932d21fSUday Bondhugula     if (isOpTriviallyDead(op)) {
6007932d21fSUday Bondhugula       notifyOperationRemoved(op);
6017932d21fSUday Bondhugula       op->erase();
6027932d21fSUday Bondhugula       changed = true;
6037932d21fSUday Bondhugula       continue;
6047932d21fSUday Bondhugula     }
6057932d21fSUday Bondhugula 
6067932d21fSUday Bondhugula     // Collects all the operands and result uses of the given `op` into work
6077932d21fSUday Bondhugula     // list. Also remove `op` and nested ops from worklist.
6087932d21fSUday Bondhugula     originalOperands.assign(op->operand_begin(), op->operand_end());
6097932d21fSUday Bondhugula     auto preReplaceAction = [&](Operation *op) {
6107932d21fSUday Bondhugula       // Add the operands to the worklist for visitation.
6117932d21fSUday Bondhugula       addOperandsToWorklist(originalOperands);
6127932d21fSUday Bondhugula 
6137932d21fSUday Bondhugula       // Add all the users of the result to the worklist so we make sure
6147932d21fSUday Bondhugula       // to revisit them.
615*ba3a9f51SChia-hung Duan       for (Value result : op->getResults()) {
616*ba3a9f51SChia-hung Duan         for (Operation *userOp : result.getUsers())
6177932d21fSUday Bondhugula           addToWorklist(userOp);
6187932d21fSUday Bondhugula       }
619*ba3a9f51SChia-hung Duan 
6207932d21fSUday Bondhugula       notifyOperationRemoved(op);
6217932d21fSUday Bondhugula     };
6227932d21fSUday Bondhugula 
6237932d21fSUday Bondhugula     // Add the given operation generated by the folder to the worklist.
6247932d21fSUday Bondhugula     auto processGeneratedConstants = [this](Operation *op) {
625*ba3a9f51SChia-hung Duan       notifyOperationInserted(op);
6267932d21fSUday Bondhugula     };
6277932d21fSUday Bondhugula 
6287932d21fSUday Bondhugula     // Try to fold this op.
6297932d21fSUday Bondhugula     bool inPlaceUpdate;
6307932d21fSUday Bondhugula     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
6317932d21fSUday Bondhugula                                    preReplaceAction, &inPlaceUpdate))) {
6327932d21fSUday Bondhugula       changed = true;
6337932d21fSUday Bondhugula       if (!inPlaceUpdate) {
6347932d21fSUday Bondhugula         // Op has been erased.
6357932d21fSUday Bondhugula         continue;
6367932d21fSUday Bondhugula       }
6377932d21fSUday Bondhugula     }
6387932d21fSUday Bondhugula 
6397932d21fSUday Bondhugula     // Try to match one of the patterns. The rewriter is automatically
6407932d21fSUday Bondhugula     // notified of any necessary changes, so there is nothing else to do
6417932d21fSUday Bondhugula     // here.
6427932d21fSUday Bondhugula     changed |= succeeded(matcher.matchAndRewrite(op, *this));
6437932d21fSUday Bondhugula   }
6447932d21fSUday Bondhugula 
6457932d21fSUday Bondhugula   return changed;
6467932d21fSUday Bondhugula }
6477932d21fSUday Bondhugula 
64804b5274eSUday Bondhugula /// Rewrites only `op` using the supplied canonicalization patterns and
64904b5274eSUday Bondhugula /// folding. `erased` is set to true if the op is erased as a result of being
65004b5274eSUday Bondhugula /// folded, replaced, or dead.
applyOpPatternsAndFold(Operation * op,const FrozenRewritePatternSet & patterns,bool * erased)6513e98fbf4SRiver Riddle LogicalResult mlir::applyOpPatternsAndFold(
65279d7f618SChris Lattner     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
65304b5274eSUday Bondhugula   // Start the pattern driver.
65464716b2cSChris Lattner   GreedyRewriteConfig config;
65504b5274eSUday Bondhugula   OpPatternRewriteDriver driver(op->getContext(), patterns);
65604b5274eSUday Bondhugula   bool opErased;
6573e98fbf4SRiver Riddle   LogicalResult converged =
65864716b2cSChris Lattner       driver.simplifyLocally(op, config.maxIterations, opErased);
65904b5274eSUday Bondhugula   if (erased)
66004b5274eSUday Bondhugula     *erased = opErased;
6613e98fbf4SRiver Riddle   LLVM_DEBUG(if (failed(converged)) {
66204b5274eSUday Bondhugula     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
66364716b2cSChris Lattner                  << config.maxIterations << " times";
66404b5274eSUday Bondhugula   });
66504b5274eSUday Bondhugula   return converged;
66604b5274eSUday Bondhugula }
6677932d21fSUday Bondhugula 
applyOpPatternsAndFold(ArrayRef<Operation * > ops,const FrozenRewritePatternSet & patterns,bool strict)6687932d21fSUday Bondhugula bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
6697932d21fSUday Bondhugula                                   const FrozenRewritePatternSet &patterns,
6707932d21fSUday Bondhugula                                   bool strict) {
6717932d21fSUday Bondhugula   if (ops.empty())
6727932d21fSUday Bondhugula     return false;
6737932d21fSUday Bondhugula 
6747932d21fSUday Bondhugula   // Start the pattern driver.
6757932d21fSUday Bondhugula   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
6767932d21fSUday Bondhugula                                      strict);
6777932d21fSUday Bondhugula   return driver.simplifyLocally(ops);
6787932d21fSUday Bondhugula }
679