164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
664d52014SChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
864d52014SChris Lattner //
9a60fdd2bSLorenzo Chelini // This file implements mlir::applyPatternsAndFoldGreedily.
1064d52014SChris Lattner //
1164d52014SChris Lattner //===----------------------------------------------------------------------===//
1264d52014SChris Lattner
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14af371f9fSRiver Riddle #include "mlir/IR/Matchers.h"
15eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
16b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
171982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
18fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
1964d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
205c757087SFeng Liu #include "llvm/Support/CommandLine.h"
215c757087SFeng Liu #include "llvm/Support/Debug.h"
225652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h"
235c757087SFeng Liu #include "llvm/Support/raw_ostream.h"
244e40c832SLei Zhang
2564d52014SChris Lattner using namespace mlir;
2664d52014SChris Lattner
275652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter"
285c757087SFeng Liu
2904b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
3004b5274eSUday Bondhugula // GreedyPatternRewriteDriver
3104b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
3204b5274eSUday Bondhugula
3364d52014SChris Lattner namespace {
3464d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3564d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
364bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3764d52014SChris Lattner public:
382566a72aSRiver Riddle explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
39648f34a2SChris Lattner const FrozenRewritePatternSet &patterns,
40b7144ab7SRiver Riddle const GreedyRewriteConfig &config);
413e98fbf4SRiver Riddle
42b7144ab7SRiver Riddle /// Simplify the operations within the given regions.
4364716b2cSChris Lattner bool simplify(MutableArrayRef<Region> regions);
4464d52014SChris Lattner
45b7144ab7SRiver Riddle /// Add the given operation to the worklist.
46*ba3a9f51SChia-hung Duan virtual void addToWorklist(Operation *op);
475c4f1fddSRiver Riddle
48b7144ab7SRiver Riddle /// Pop the next operation from the worklist.
49b7144ab7SRiver Riddle Operation *popFromWorklist();
5064d52014SChris Lattner
51b7144ab7SRiver Riddle /// If the specified operation is in the worklist, remove it.
52b7144ab7SRiver Riddle void removeFromWorklist(Operation *op);
5364d52014SChris Lattner
544bd9f936SChris Lattner protected:
55851a8516SRiver Riddle // Implement the hook for inserting operations, and make sure that newly
56851a8516SRiver Riddle // inserted ops are added to the worklist for processing.
57b7144ab7SRiver Riddle void notifyOperationInserted(Operation *op) override;
5864d52014SChris Lattner
597932d21fSUday Bondhugula // Look over the provided operands for any defining operations that should
607932d21fSUday Bondhugula // be re-added to the worklist. This function should be called when an
617932d21fSUday Bondhugula // operation is modified or removed, as it may trigger further
627932d21fSUday Bondhugula // simplifications.
63*ba3a9f51SChia-hung Duan void addOperandsToWorklist(ValueRange operands);
647932d21fSUday Bondhugula
6564d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our
6664d52014SChris Lattner // worklist anymore because we'd get dangling references to it.
67b7144ab7SRiver Riddle void notifyOperationRemoved(Operation *op) override;
6864d52014SChris Lattner
69085b687fSChris Lattner // When the root of a pattern is about to be replaced, it can trigger
70085b687fSChris Lattner // simplifications to its users - make sure to add them to the worklist
71085b687fSChris Lattner // before the root is changed.
72b7144ab7SRiver Riddle void notifyRootReplaced(Operation *op) override;
73085b687fSChris Lattner
745652ecc3SRiver Riddle /// PatternRewriter hook for erasing a dead operation.
755652ecc3SRiver Riddle void eraseOp(Operation *op) override;
765652ecc3SRiver Riddle
775652ecc3SRiver Riddle /// PatternRewriter hook for notifying match failure reasons.
785652ecc3SRiver Riddle LogicalResult
79ea64828aSRiver Riddle notifyMatchFailure(Location loc,
805652ecc3SRiver Riddle function_ref<void(Diagnostic &)> reasonCallback) override;
815652ecc3SRiver Riddle
823e98fbf4SRiver Riddle /// The low-level pattern applicator.
833e98fbf4SRiver Riddle PatternApplicator matcher;
844bd9f936SChris Lattner
854bd9f936SChris Lattner /// The worklist for this transformation keeps track of the operations that
864bd9f936SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to
87e7a2ef21SRiver Riddle /// efficiently remove operations from the worklist when they are erased, even
88e7a2ef21SRiver Riddle /// if they aren't the root of a pattern.
8999b87c97SRiver Riddle std::vector<Operation *> worklist;
9099b87c97SRiver Riddle DenseMap<Operation *, unsigned> worklistMap;
9160a29837SRiver Riddle
9260a29837SRiver Riddle /// Non-pattern based folder for operations.
9360a29837SRiver Riddle OperationFolder folder;
94648f34a2SChris Lattner
957932d21fSUday Bondhugula private:
9664716b2cSChris Lattner /// Configuration information for how to simplify.
9764716b2cSChris Lattner GreedyRewriteConfig config;
985652ecc3SRiver Riddle
995652ecc3SRiver Riddle #ifndef NDEBUG
1005652ecc3SRiver Riddle /// A logger used to emit information during the application process.
1015652ecc3SRiver Riddle llvm::ScopedPrinter logger{llvm::dbgs()};
1025652ecc3SRiver Riddle #endif
10364d52014SChris Lattner };
104be0a7e9fSMehdi Amini } // namespace
10564d52014SChris Lattner
GreedyPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,const GreedyRewriteConfig & config)106b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
107b7144ab7SRiver Riddle MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
108b7144ab7SRiver Riddle const GreedyRewriteConfig &config)
109b7144ab7SRiver Riddle : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
110b7144ab7SRiver Riddle worklist.reserve(64);
111b7144ab7SRiver Riddle
112b7144ab7SRiver Riddle // Apply a simple cost model based solely on pattern benefit.
113b7144ab7SRiver Riddle matcher.applyDefaultCostModel();
114b7144ab7SRiver Riddle }
115b7144ab7SRiver Riddle
simplify(MutableArrayRef<Region> regions)11664716b2cSChris Lattner bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
1175652ecc3SRiver Riddle #ifndef NDEBUG
1185652ecc3SRiver Riddle const char *logLineComment =
1195652ecc3SRiver Riddle "//===-------------------------------------------===//\n";
1205652ecc3SRiver Riddle
1215652ecc3SRiver Riddle /// A utility function to log a process result for the given reason.
1225652ecc3SRiver Riddle auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
1235652ecc3SRiver Riddle logger.unindent();
1245652ecc3SRiver Riddle logger.startLine() << "} -> " << result;
1255652ecc3SRiver Riddle if (!msg.isTriviallyEmpty())
1265652ecc3SRiver Riddle logger.getOStream() << " : " << msg;
1275652ecc3SRiver Riddle logger.getOStream() << "\n";
1285652ecc3SRiver Riddle };
1295652ecc3SRiver Riddle auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
1305652ecc3SRiver Riddle logResult(result, msg);
1315652ecc3SRiver Riddle logger.startLine() << logLineComment;
1325652ecc3SRiver Riddle };
1335652ecc3SRiver Riddle #endif
1345652ecc3SRiver Riddle
1357814b559Srkayaith auto insertKnownConstant = [&](Operation *op) {
1367814b559Srkayaith // Check for existing constants when populating the worklist. This avoids
1377814b559Srkayaith // accidentally reversing the constant order during processing.
1387814b559Srkayaith Attribute constValue;
1397814b559Srkayaith if (matchPattern(op, m_Constant(&constValue)))
1407814b559Srkayaith if (!folder.insertKnownConstant(op, constValue))
1417814b559Srkayaith return true;
1427814b559Srkayaith return false;
1437814b559Srkayaith };
1447814b559Srkayaith
1455c757087SFeng Liu bool changed = false;
14664716b2cSChris Lattner unsigned iteration = 0;
1475c757087SFeng Liu do {
148648f34a2SChris Lattner worklist.clear();
149648f34a2SChris Lattner worklistMap.clear();
150648f34a2SChris Lattner
15164716b2cSChris Lattner if (!config.useTopDownTraversal) {
15264716b2cSChris Lattner // Add operations to the worklist in postorder.
153af371f9fSRiver Riddle for (auto ®ion : regions) {
1547814b559Srkayaith region.walk([&](Operation *op) {
1557814b559Srkayaith if (!insertKnownConstant(op))
156af371f9fSRiver Riddle addToWorklist(op);
157af371f9fSRiver Riddle });
158af371f9fSRiver Riddle }
15964716b2cSChris Lattner } else {
160648f34a2SChris Lattner // Add all nested operations to the worklist in preorder.
161aa568e08SRiver Riddle for (auto ®ion : regions) {
1627814b559Srkayaith region.walk<WalkOrder::PreOrder>([&](Operation *op) {
163aa568e08SRiver Riddle if (!insertKnownConstant(op)) {
1647814b559Srkayaith worklist.push_back(op);
165aa568e08SRiver Riddle return WalkResult::advance();
166aa568e08SRiver Riddle }
167aa568e08SRiver Riddle return WalkResult::skip();
1687814b559Srkayaith });
169aa568e08SRiver Riddle }
170648f34a2SChris Lattner
171648f34a2SChris Lattner // Reverse the list so our pop-back loop processes them in-order.
172648f34a2SChris Lattner std::reverse(worklist.begin(), worklist.end());
173648f34a2SChris Lattner // Remember the reverse index.
174648f34a2SChris Lattner for (size_t i = 0, e = worklist.size(); i != e; ++i)
175648f34a2SChris Lattner worklistMap[worklist[i]] = i;
176648f34a2SChris Lattner }
1774e40c832SLei Zhang
1784e40c832SLei Zhang // These are scratch vectors used in the folding loop below.
179e62a6956SRiver Riddle SmallVector<Value, 8> originalOperands, resultValues;
18064d52014SChris Lattner
1815c757087SFeng Liu changed = false;
18264d52014SChris Lattner while (!worklist.empty()) {
18364d52014SChris Lattner auto *op = popFromWorklist();
18464d52014SChris Lattner
1855c757087SFeng Liu // Nulls get added to the worklist when operations are removed, ignore
1865c757087SFeng Liu // them.
18764d52014SChris Lattner if (op == nullptr)
18864d52014SChris Lattner continue;
18964d52014SChris Lattner
1905652ecc3SRiver Riddle LLVM_DEBUG({
1915652ecc3SRiver Riddle logger.getOStream() << "\n";
1925652ecc3SRiver Riddle logger.startLine() << logLineComment;
1935652ecc3SRiver Riddle logger.startLine() << "Processing operation : '" << op->getName()
1945652ecc3SRiver Riddle << "'(" << op << ") {\n";
1955652ecc3SRiver Riddle logger.indent();
1965652ecc3SRiver Riddle
1975652ecc3SRiver Riddle // If the operation has no regions, just print it here.
1985652ecc3SRiver Riddle if (op->getNumRegions() == 0) {
1995652ecc3SRiver Riddle op->print(
2005652ecc3SRiver Riddle logger.startLine(),
2015652ecc3SRiver Riddle OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
2025652ecc3SRiver Riddle logger.getOStream() << "\n\n";
2035652ecc3SRiver Riddle }
2045652ecc3SRiver Riddle });
2055652ecc3SRiver Riddle
2060ddba0bdSRiver Riddle // If the operation is trivially dead - remove it.
2070ddba0bdSRiver Riddle if (isOpTriviallyDead(op)) {
2086a501e3dSAndy Ly notifyOperationRemoved(op);
20964d52014SChris Lattner op->erase();
210f875e55bSUday Bondhugula changed = true;
2115652ecc3SRiver Riddle
2125652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
21364d52014SChris Lattner continue;
21464d52014SChris Lattner }
21564d52014SChris Lattner
2164e40c832SLei Zhang // Collects all the operands and result uses of the given `op` into work
2176a501e3dSAndy Ly // list. Also remove `op` and nested ops from worklist.
2181982afb1SRiver Riddle originalOperands.assign(op->operand_begin(), op->operand_end());
2196a501e3dSAndy Ly auto preReplaceAction = [&](Operation *op) {
220a8866258SRiver Riddle // Add the operands to the worklist for visitation.
221*ba3a9f51SChia-hung Duan addOperandsToWorklist(originalOperands);
2221982afb1SRiver Riddle
2234e40c832SLei Zhang // Add all the users of the result to the worklist so we make sure
2244e40c832SLei Zhang // to revisit them.
22535807bc4SRiver Riddle for (auto result : op->getResults())
226cc673894SUday Bondhugula for (auto *userOp : result.getUsers())
227cc673894SUday Bondhugula addToWorklist(userOp);
2286a501e3dSAndy Ly
2296a501e3dSAndy Ly notifyOperationRemoved(op);
2304e40c832SLei Zhang };
23164d52014SChris Lattner
232648f34a2SChris Lattner // Add the given operation to the worklist.
233648f34a2SChris Lattner auto collectOps = [this](Operation *op) { addToWorklist(op); };
234648f34a2SChris Lattner
2351982afb1SRiver Riddle // Try to fold this op.
236cbcb12fdSUday Bondhugula bool inPlaceUpdate;
237cbcb12fdSUday Bondhugula if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
238cbcb12fdSUday Bondhugula &inPlaceUpdate)))) {
2395652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
2405652ecc3SRiver Riddle
241f875e55bSUday Bondhugula changed = true;
242cbcb12fdSUday Bondhugula if (!inPlaceUpdate)
243934b6d12SChris Lattner continue;
24464d52014SChris Lattner }
24564d52014SChris Lattner
24632052c84SRiver Riddle // Try to match one of the patterns. The rewriter is automatically
247648f34a2SChris Lattner // notified of any necessary changes, so there is nothing else to do
248648f34a2SChris Lattner // here.
2495652ecc3SRiver Riddle #ifndef NDEBUG
2505652ecc3SRiver Riddle auto canApply = [&](const Pattern &pattern) {
2515652ecc3SRiver Riddle LLVM_DEBUG({
2525652ecc3SRiver Riddle logger.getOStream() << "\n";
2535652ecc3SRiver Riddle logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
2545652ecc3SRiver Riddle << op->getName() << " -> (";
2555652ecc3SRiver Riddle llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
2565652ecc3SRiver Riddle logger.getOStream() << ")' {\n";
2575652ecc3SRiver Riddle logger.indent();
2585652ecc3SRiver Riddle });
2595652ecc3SRiver Riddle return true;
2605652ecc3SRiver Riddle };
2615652ecc3SRiver Riddle auto onFailure = [&](const Pattern &pattern) {
2625652ecc3SRiver Riddle LLVM_DEBUG(logResult("failure", "pattern failed to match"));
2635652ecc3SRiver Riddle };
2645652ecc3SRiver Riddle auto onSuccess = [&](const Pattern &pattern) {
2655652ecc3SRiver Riddle LLVM_DEBUG(logResult("success", "pattern applied successfully"));
2665652ecc3SRiver Riddle return success();
2675652ecc3SRiver Riddle };
2685652ecc3SRiver Riddle
2695652ecc3SRiver Riddle LogicalResult matchResult =
2705652ecc3SRiver Riddle matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
2715652ecc3SRiver Riddle if (succeeded(matchResult))
2725652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
2735652ecc3SRiver Riddle else
2745652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
2755652ecc3SRiver Riddle #else
2765652ecc3SRiver Riddle LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
2775652ecc3SRiver Riddle #endif
2785652ecc3SRiver Riddle changed |= succeeded(matchResult);
27964d52014SChris Lattner }
280a32f0dcbSRiver Riddle
281648f34a2SChris Lattner // After applying patterns, make sure that the CFG of each of the regions
282648f34a2SChris Lattner // is kept up to date.
28364716b2cSChris Lattner if (config.enableRegionSimplification)
284d75a611aSRiver Riddle changed |= succeeded(simplifyRegions(*this, regions));
285519663beSFrederik Gossen } while (changed &&
286673e9828SFrederik Gossen (iteration++ < config.maxIterations ||
287519663beSFrederik Gossen config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
288648f34a2SChris Lattner
2895c757087SFeng Liu // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
2905c757087SFeng Liu return !changed;
29164d52014SChris Lattner }
29264d52014SChris Lattner
addToWorklist(Operation * op)293b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
294b7144ab7SRiver Riddle // Check to see if the worklist already contains this op.
295b7144ab7SRiver Riddle if (worklistMap.count(op))
296b7144ab7SRiver Riddle return;
297b7144ab7SRiver Riddle
298b7144ab7SRiver Riddle worklistMap[op] = worklist.size();
299b7144ab7SRiver Riddle worklist.push_back(op);
300b7144ab7SRiver Riddle }
301b7144ab7SRiver Riddle
popFromWorklist()302b7144ab7SRiver Riddle Operation *GreedyPatternRewriteDriver::popFromWorklist() {
303b7144ab7SRiver Riddle auto *op = worklist.back();
304b7144ab7SRiver Riddle worklist.pop_back();
305b7144ab7SRiver Riddle
306b7144ab7SRiver Riddle // This operation is no longer in the worklist, keep worklistMap up to date.
307b7144ab7SRiver Riddle if (op)
308b7144ab7SRiver Riddle worklistMap.erase(op);
309b7144ab7SRiver Riddle return op;
310b7144ab7SRiver Riddle }
311b7144ab7SRiver Riddle
removeFromWorklist(Operation * op)312b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
313b7144ab7SRiver Riddle auto it = worklistMap.find(op);
314b7144ab7SRiver Riddle if (it != worklistMap.end()) {
315b7144ab7SRiver Riddle assert(worklist[it->second] == op && "malformed worklist data structure");
316b7144ab7SRiver Riddle worklist[it->second] = nullptr;
317b7144ab7SRiver Riddle worklistMap.erase(it);
318b7144ab7SRiver Riddle }
319b7144ab7SRiver Riddle }
320b7144ab7SRiver Riddle
notifyOperationInserted(Operation * op)321b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
3225652ecc3SRiver Riddle LLVM_DEBUG({
3235652ecc3SRiver Riddle logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
3245652ecc3SRiver Riddle << ")\n";
3255652ecc3SRiver Riddle });
326b7144ab7SRiver Riddle addToWorklist(op);
327b7144ab7SRiver Riddle }
328b7144ab7SRiver Riddle
addOperandsToWorklist(ValueRange operands)329*ba3a9f51SChia-hung Duan void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
330b7144ab7SRiver Riddle for (Value operand : operands) {
331b7144ab7SRiver Riddle // If the use count of this operand is now < 2, we re-add the defining
332b7144ab7SRiver Riddle // operation to the worklist.
333b7144ab7SRiver Riddle // TODO: This is based on the fact that zero use operations
334b7144ab7SRiver Riddle // may be deleted, and that single use values often have more
335b7144ab7SRiver Riddle // canonicalization opportunities.
336b7144ab7SRiver Riddle if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
337b7144ab7SRiver Riddle continue;
338b7144ab7SRiver Riddle if (auto *defOp = operand.getDefiningOp())
339b7144ab7SRiver Riddle addToWorklist(defOp);
340b7144ab7SRiver Riddle }
341b7144ab7SRiver Riddle }
342b7144ab7SRiver Riddle
notifyOperationRemoved(Operation * op)343b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
344*ba3a9f51SChia-hung Duan addOperandsToWorklist(op->getOperands());
345b7144ab7SRiver Riddle op->walk([this](Operation *operation) {
346b7144ab7SRiver Riddle removeFromWorklist(operation);
347b7144ab7SRiver Riddle folder.notifyRemoval(operation);
348b7144ab7SRiver Riddle });
349b7144ab7SRiver Riddle }
350b7144ab7SRiver Riddle
notifyRootReplaced(Operation * op)351b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
3525652ecc3SRiver Riddle LLVM_DEBUG({
3535652ecc3SRiver Riddle logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
3545652ecc3SRiver Riddle << ")\n";
3555652ecc3SRiver Riddle });
356b7144ab7SRiver Riddle for (auto result : op->getResults())
357b7144ab7SRiver Riddle for (auto *user : result.getUsers())
358b7144ab7SRiver Riddle addToWorklist(user);
359b7144ab7SRiver Riddle }
360b7144ab7SRiver Riddle
eraseOp(Operation * op)3615652ecc3SRiver Riddle void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
3625652ecc3SRiver Riddle LLVM_DEBUG({
3635652ecc3SRiver Riddle logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
3645652ecc3SRiver Riddle << ")\n";
3655652ecc3SRiver Riddle });
3665652ecc3SRiver Riddle PatternRewriter::eraseOp(op);
3675652ecc3SRiver Riddle }
3685652ecc3SRiver Riddle
notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)3695652ecc3SRiver Riddle LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
370ea64828aSRiver Riddle Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
3715652ecc3SRiver Riddle LLVM_DEBUG({
372ea64828aSRiver Riddle Diagnostic diag(loc, DiagnosticSeverity::Remark);
3735652ecc3SRiver Riddle reasonCallback(diag);
3745652ecc3SRiver Riddle logger.startLine() << "** Failure : " << diag.str() << "\n";
3755652ecc3SRiver Riddle });
3765652ecc3SRiver Riddle return failure();
3775652ecc3SRiver Riddle }
3785652ecc3SRiver Riddle
379e7a2ef21SRiver Riddle /// Rewrite the regions of the specified operation, which must be isolated from
380e7a2ef21SRiver Riddle /// above, by repeatedly applying the highest benefit patterns in a greedy
3813e98fbf4SRiver Riddle /// work-list driven manner. Return success if no more patterns can be matched
3823e98fbf4SRiver Riddle /// in the result operation regions. Note: This does not apply patterns to the
3833e98fbf4SRiver Riddle /// top-level operation itself.
38464d52014SChris Lattner ///
3853e98fbf4SRiver Riddle LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,const FrozenRewritePatternSet & patterns,GreedyRewriteConfig config)3863e98fbf4SRiver Riddle mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
3870b20413eSUday Bondhugula const FrozenRewritePatternSet &patterns,
38864716b2cSChris Lattner GreedyRewriteConfig config) {
3896b1cc3c6SRiver Riddle if (regions.empty())
3903e98fbf4SRiver Riddle return success();
3916b1cc3c6SRiver Riddle
392e7a2ef21SRiver Riddle // The top-level operation must be known to be isolated from above to
393e7a2ef21SRiver Riddle // prevent performing canonicalizations on operations defined at or above
394e7a2ef21SRiver Riddle // the region containing 'op'.
3956b1cc3c6SRiver Riddle auto regionIsIsolated = [](Region ®ion) {
396fe7c0d90SRiver Riddle return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
3976b1cc3c6SRiver Riddle };
3986b1cc3c6SRiver Riddle (void)regionIsIsolated;
3996b1cc3c6SRiver Riddle assert(llvm::all_of(regions, regionIsIsolated) &&
4006b1cc3c6SRiver Riddle "patterns can only be applied to operations IsolatedFromAbove");
401e7a2ef21SRiver Riddle
4026b1cc3c6SRiver Riddle // Start the pattern driver.
40364716b2cSChris Lattner GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
40464716b2cSChris Lattner bool converged = driver.simplify(regions);
4055c757087SFeng Liu LLVM_DEBUG(if (!converged) {
406e7a2ef21SRiver Riddle llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
40764716b2cSChris Lattner << config.maxIterations << " times\n";
4085c757087SFeng Liu });
4093e98fbf4SRiver Riddle return success(converged);
41064d52014SChris Lattner }
41104b5274eSUday Bondhugula
41204b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
41304b5274eSUday Bondhugula // OpPatternRewriteDriver
41404b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
41504b5274eSUday Bondhugula
41604b5274eSUday Bondhugula namespace {
41704b5274eSUday Bondhugula /// This is a simple driver for the PatternMatcher to apply patterns and perform
41804b5274eSUday Bondhugula /// folding on a single op. It repeatedly applies locally optimal patterns.
41904b5274eSUday Bondhugula class OpPatternRewriteDriver : public PatternRewriter {
42004b5274eSUday Bondhugula public:
OpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns)42104b5274eSUday Bondhugula explicit OpPatternRewriteDriver(MLIRContext *ctx,
42279d7f618SChris Lattner const FrozenRewritePatternSet &patterns)
4233e98fbf4SRiver Riddle : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
4243e98fbf4SRiver Riddle // Apply a simple cost model based solely on pattern benefit.
4253e98fbf4SRiver Riddle matcher.applyDefaultCostModel();
4263e98fbf4SRiver Riddle }
42704b5274eSUday Bondhugula
4283e98fbf4SRiver Riddle LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
42904b5274eSUday Bondhugula
43004b5274eSUday Bondhugula // These are hooks implemented for PatternRewriter.
43104b5274eSUday Bondhugula protected:
43204b5274eSUday Bondhugula /// If an operation is about to be removed, mark it so that we can let clients
43304b5274eSUday Bondhugula /// know.
notifyOperationRemoved(Operation * op)43404b5274eSUday Bondhugula void notifyOperationRemoved(Operation *op) override {
43504b5274eSUday Bondhugula opErasedViaPatternRewrites = true;
43604b5274eSUday Bondhugula }
43704b5274eSUday Bondhugula
43804b5274eSUday Bondhugula // When a root is going to be replaced, its removal will be notified as well.
43904b5274eSUday Bondhugula // So there is nothing to do here.
notifyRootReplaced(Operation * op)44004b5274eSUday Bondhugula void notifyRootReplaced(Operation *op) override {}
44104b5274eSUday Bondhugula
44204b5274eSUday Bondhugula private:
4433e98fbf4SRiver Riddle /// The low-level pattern applicator.
4443e98fbf4SRiver Riddle PatternApplicator matcher;
44504b5274eSUday Bondhugula
44604b5274eSUday Bondhugula /// Non-pattern based folder for operations.
44704b5274eSUday Bondhugula OperationFolder folder;
44804b5274eSUday Bondhugula
44904b5274eSUday Bondhugula /// Set to true if the operation has been erased via pattern rewrites.
45004b5274eSUday Bondhugula bool opErasedViaPatternRewrites = false;
45104b5274eSUday Bondhugula };
45204b5274eSUday Bondhugula
453be0a7e9fSMehdi Amini } // namespace
45404b5274eSUday Bondhugula
4557932d21fSUday Bondhugula /// Performs the rewrites and folding only on `op`. The simplification
4567932d21fSUday Bondhugula /// converges if the op is erased as a result of being folded, replaced, or
4577932d21fSUday Bondhugula /// becoming dead, or no more changes happen in an iteration. Returns success if
4587932d21fSUday Bondhugula /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
4597932d21fSUday Bondhugula /// gets erased.
simplifyLocally(Operation * op,int maxIterations,bool & erased)4603e98fbf4SRiver Riddle LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
4613e98fbf4SRiver Riddle int maxIterations,
46204b5274eSUday Bondhugula bool &erased) {
46304b5274eSUday Bondhugula bool changed = false;
46404b5274eSUday Bondhugula erased = false;
46504b5274eSUday Bondhugula opErasedViaPatternRewrites = false;
4667932d21fSUday Bondhugula int iterations = 0;
46704b5274eSUday Bondhugula // Iterate until convergence or until maxIterations. Deletion of the op as
46804b5274eSUday Bondhugula // a result of being dead or folded is convergence.
46904b5274eSUday Bondhugula do {
470ff87c4d3SChristian Sigg changed = false;
471ff87c4d3SChristian Sigg
47204b5274eSUday Bondhugula // If the operation is trivially dead - remove it.
47304b5274eSUday Bondhugula if (isOpTriviallyDead(op)) {
47404b5274eSUday Bondhugula op->erase();
47504b5274eSUday Bondhugula erased = true;
4763e98fbf4SRiver Riddle return success();
47704b5274eSUday Bondhugula }
47804b5274eSUday Bondhugula
47904b5274eSUday Bondhugula // Try to fold this op.
48004b5274eSUday Bondhugula bool inPlaceUpdate;
48104b5274eSUday Bondhugula if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
48204b5274eSUday Bondhugula /*preReplaceAction=*/nullptr,
48304b5274eSUday Bondhugula &inPlaceUpdate))) {
48404b5274eSUday Bondhugula changed = true;
48504b5274eSUday Bondhugula if (!inPlaceUpdate) {
48604b5274eSUday Bondhugula erased = true;
4873e98fbf4SRiver Riddle return success();
48804b5274eSUday Bondhugula }
48904b5274eSUday Bondhugula }
49004b5274eSUday Bondhugula
49104b5274eSUday Bondhugula // Try to match one of the patterns. The rewriter is automatically
49204b5274eSUday Bondhugula // notified of any necessary changes, so there is nothing else to do here.
4933e98fbf4SRiver Riddle changed |= succeeded(matcher.matchAndRewrite(op, *this));
49404b5274eSUday Bondhugula if ((erased = opErasedViaPatternRewrites))
4953e98fbf4SRiver Riddle return success();
496519663beSFrederik Gossen } while (changed &&
497519663beSFrederik Gossen (++iterations < maxIterations ||
498519663beSFrederik Gossen maxIterations == GreedyRewriteConfig::kNoIterationLimit));
49904b5274eSUday Bondhugula
50004b5274eSUday Bondhugula // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
5013e98fbf4SRiver Riddle return failure(changed);
50204b5274eSUday Bondhugula }
50304b5274eSUday Bondhugula
5047932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
5057932d21fSUday Bondhugula // MultiOpPatternRewriteDriver
5067932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
5077932d21fSUday Bondhugula
5087932d21fSUday Bondhugula namespace {
5097932d21fSUday Bondhugula
5107932d21fSUday Bondhugula /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
5117932d21fSUday Bondhugula /// perform folding for a supplied set of ops. It repeatedly simplifies while
5127932d21fSUday Bondhugula /// restricting the rewrites to only the provided set of ops or optionally
5137932d21fSUday Bondhugula /// to those directly affected by it (result users or operand providers).
5147932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
5157932d21fSUday Bondhugula public:
MultiOpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,bool strict)5167932d21fSUday Bondhugula explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
5177932d21fSUday Bondhugula const FrozenRewritePatternSet &patterns,
5187932d21fSUday Bondhugula bool strict)
5197932d21fSUday Bondhugula : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
5207932d21fSUday Bondhugula strictMode(strict) {}
5217932d21fSUday Bondhugula
5227932d21fSUday Bondhugula bool simplifyLocally(ArrayRef<Operation *> op);
5237932d21fSUday Bondhugula
addToWorklist(Operation * op)524*ba3a9f51SChia-hung Duan void addToWorklist(Operation *op) override {
525*ba3a9f51SChia-hung Duan if (!strictMode || strictModeFilteredOps.contains(op))
526*ba3a9f51SChia-hung Duan GreedyPatternRewriteDriver::addToWorklist(op);
5277932d21fSUday Bondhugula }
5287932d21fSUday Bondhugula
529*ba3a9f51SChia-hung Duan private:
notifyOperationInserted(Operation * op)5302aeffc6dSChia-hung Duan void notifyOperationInserted(Operation *op) override {
5312aeffc6dSChia-hung Duan GreedyPatternRewriteDriver::notifyOperationInserted(op);
5322aeffc6dSChia-hung Duan if (strictMode)
5332aeffc6dSChia-hung Duan strictModeFilteredOps.insert(op);
5342aeffc6dSChia-hung Duan }
5352aeffc6dSChia-hung Duan
notifyOperationRemoved(Operation * op)5367932d21fSUday Bondhugula void notifyOperationRemoved(Operation *op) override {
5377932d21fSUday Bondhugula GreedyPatternRewriteDriver::notifyOperationRemoved(op);
5387932d21fSUday Bondhugula if (strictMode)
5397932d21fSUday Bondhugula strictModeFilteredOps.erase(op);
5407932d21fSUday Bondhugula }
5417932d21fSUday Bondhugula
5427932d21fSUday Bondhugula /// If `strictMode` is true, any pre-existing ops outside of
5437932d21fSUday Bondhugula /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
5447932d21fSUday Bondhugula /// If `strictMode` is false, operations that use results of (or supply
5457932d21fSUday Bondhugula /// operands to) any rewritten ops stemming from the simplification of the
5467932d21fSUday Bondhugula /// provided ops are in turn simplified; any other ops still remain untouched
5477932d21fSUday Bondhugula /// (i.e., regardless of `strictMode`).
5487932d21fSUday Bondhugula bool strictMode = false;
5497932d21fSUday Bondhugula
5507932d21fSUday Bondhugula /// The list of ops we are restricting our rewrites to if `strictMode` is on.
5517932d21fSUday Bondhugula /// These include the supplied set of ops as well as new ops created while
5527932d21fSUday Bondhugula /// rewriting those ops. This set is not maintained when strictMode is off.
5537932d21fSUday Bondhugula llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
5547932d21fSUday Bondhugula };
5557932d21fSUday Bondhugula
556be0a7e9fSMehdi Amini } // namespace
5577932d21fSUday Bondhugula
5587932d21fSUday Bondhugula /// Performs the specified rewrites on `ops` while also trying to fold these ops
5597932d21fSUday Bondhugula /// as well as any other ops that were in turn created due to these rewrite
5607932d21fSUday Bondhugula /// patterns. Any pre-existing ops outside of `ops` remain completely
5617932d21fSUday Bondhugula /// unmodified if `strictMode` is true. If `strictMode` is false, other
5627932d21fSUday Bondhugula /// operations that use results of rewritten ops or supply operands to such ops
5637932d21fSUday Bondhugula /// are in turn simplified; any other ops still remain unmodified (i.e.,
5647932d21fSUday Bondhugula /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
5657932d21fSUday Bondhugula /// result of folding, becoming dead, or via pattern rewrites. Returns true if
5667932d21fSUday Bondhugula /// at all any changes happened.
5677932d21fSUday Bondhugula // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
5687932d21fSUday Bondhugula // or GreedyPatternRewriteDriver::simplify, this method just iterates until
5697932d21fSUday Bondhugula // the worklist is empty. As our objective is to keep simplification "local",
5707932d21fSUday Bondhugula // there is no strong rationale to re-add all operations into the worklist and
5717932d21fSUday Bondhugula // rerun until an iteration changes nothing. If more widereaching simplification
5727932d21fSUday Bondhugula // is desired, GreedyPatternRewriteDriver should be used.
simplifyLocally(ArrayRef<Operation * > ops)5737932d21fSUday Bondhugula bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
5747932d21fSUday Bondhugula if (strictMode) {
5757932d21fSUday Bondhugula strictModeFilteredOps.clear();
5767932d21fSUday Bondhugula strictModeFilteredOps.insert(ops.begin(), ops.end());
5777932d21fSUday Bondhugula }
5787932d21fSUday Bondhugula
5797932d21fSUday Bondhugula bool changed = false;
5807932d21fSUday Bondhugula worklist.clear();
5817932d21fSUday Bondhugula worklistMap.clear();
5827932d21fSUday Bondhugula for (Operation *op : ops)
5837932d21fSUday Bondhugula addToWorklist(op);
5847932d21fSUday Bondhugula
5857932d21fSUday Bondhugula // These are scratch vectors used in the folding loop below.
5867932d21fSUday Bondhugula SmallVector<Value, 8> originalOperands, resultValues;
5877932d21fSUday Bondhugula while (!worklist.empty()) {
5887932d21fSUday Bondhugula Operation *op = popFromWorklist();
5897932d21fSUday Bondhugula
5907932d21fSUday Bondhugula // Nulls get added to the worklist when operations are removed, ignore
5917932d21fSUday Bondhugula // them.
5927932d21fSUday Bondhugula if (op == nullptr)
5937932d21fSUday Bondhugula continue;
5947932d21fSUday Bondhugula
595633ad1d8SChia-hung Duan assert((!strictMode || strictModeFilteredOps.contains(op)) &&
596633ad1d8SChia-hung Duan "unexpected op was inserted under strict mode");
597633ad1d8SChia-hung Duan
5987932d21fSUday Bondhugula // If the operation is trivially dead - remove it.
5997932d21fSUday Bondhugula if (isOpTriviallyDead(op)) {
6007932d21fSUday Bondhugula notifyOperationRemoved(op);
6017932d21fSUday Bondhugula op->erase();
6027932d21fSUday Bondhugula changed = true;
6037932d21fSUday Bondhugula continue;
6047932d21fSUday Bondhugula }
6057932d21fSUday Bondhugula
6067932d21fSUday Bondhugula // Collects all the operands and result uses of the given `op` into work
6077932d21fSUday Bondhugula // list. Also remove `op` and nested ops from worklist.
6087932d21fSUday Bondhugula originalOperands.assign(op->operand_begin(), op->operand_end());
6097932d21fSUday Bondhugula auto preReplaceAction = [&](Operation *op) {
6107932d21fSUday Bondhugula // Add the operands to the worklist for visitation.
6117932d21fSUday Bondhugula addOperandsToWorklist(originalOperands);
6127932d21fSUday Bondhugula
6137932d21fSUday Bondhugula // Add all the users of the result to the worklist so we make sure
6147932d21fSUday Bondhugula // to revisit them.
615*ba3a9f51SChia-hung Duan for (Value result : op->getResults()) {
616*ba3a9f51SChia-hung Duan for (Operation *userOp : result.getUsers())
6177932d21fSUday Bondhugula addToWorklist(userOp);
6187932d21fSUday Bondhugula }
619*ba3a9f51SChia-hung Duan
6207932d21fSUday Bondhugula notifyOperationRemoved(op);
6217932d21fSUday Bondhugula };
6227932d21fSUday Bondhugula
6237932d21fSUday Bondhugula // Add the given operation generated by the folder to the worklist.
6247932d21fSUday Bondhugula auto processGeneratedConstants = [this](Operation *op) {
625*ba3a9f51SChia-hung Duan notifyOperationInserted(op);
6267932d21fSUday Bondhugula };
6277932d21fSUday Bondhugula
6287932d21fSUday Bondhugula // Try to fold this op.
6297932d21fSUday Bondhugula bool inPlaceUpdate;
6307932d21fSUday Bondhugula if (succeeded(folder.tryToFold(op, processGeneratedConstants,
6317932d21fSUday Bondhugula preReplaceAction, &inPlaceUpdate))) {
6327932d21fSUday Bondhugula changed = true;
6337932d21fSUday Bondhugula if (!inPlaceUpdate) {
6347932d21fSUday Bondhugula // Op has been erased.
6357932d21fSUday Bondhugula continue;
6367932d21fSUday Bondhugula }
6377932d21fSUday Bondhugula }
6387932d21fSUday Bondhugula
6397932d21fSUday Bondhugula // Try to match one of the patterns. The rewriter is automatically
6407932d21fSUday Bondhugula // notified of any necessary changes, so there is nothing else to do
6417932d21fSUday Bondhugula // here.
6427932d21fSUday Bondhugula changed |= succeeded(matcher.matchAndRewrite(op, *this));
6437932d21fSUday Bondhugula }
6447932d21fSUday Bondhugula
6457932d21fSUday Bondhugula return changed;
6467932d21fSUday Bondhugula }
6477932d21fSUday Bondhugula
64804b5274eSUday Bondhugula /// Rewrites only `op` using the supplied canonicalization patterns and
64904b5274eSUday Bondhugula /// folding. `erased` is set to true if the op is erased as a result of being
65004b5274eSUday Bondhugula /// folded, replaced, or dead.
applyOpPatternsAndFold(Operation * op,const FrozenRewritePatternSet & patterns,bool * erased)6513e98fbf4SRiver Riddle LogicalResult mlir::applyOpPatternsAndFold(
65279d7f618SChris Lattner Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
65304b5274eSUday Bondhugula // Start the pattern driver.
65464716b2cSChris Lattner GreedyRewriteConfig config;
65504b5274eSUday Bondhugula OpPatternRewriteDriver driver(op->getContext(), patterns);
65604b5274eSUday Bondhugula bool opErased;
6573e98fbf4SRiver Riddle LogicalResult converged =
65864716b2cSChris Lattner driver.simplifyLocally(op, config.maxIterations, opErased);
65904b5274eSUday Bondhugula if (erased)
66004b5274eSUday Bondhugula *erased = opErased;
6613e98fbf4SRiver Riddle LLVM_DEBUG(if (failed(converged)) {
66204b5274eSUday Bondhugula llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
66364716b2cSChris Lattner << config.maxIterations << " times";
66404b5274eSUday Bondhugula });
66504b5274eSUday Bondhugula return converged;
66604b5274eSUday Bondhugula }
6677932d21fSUday Bondhugula
applyOpPatternsAndFold(ArrayRef<Operation * > ops,const FrozenRewritePatternSet & patterns,bool strict)6687932d21fSUday Bondhugula bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
6697932d21fSUday Bondhugula const FrozenRewritePatternSet &patterns,
6707932d21fSUday Bondhugula bool strict) {
6717932d21fSUday Bondhugula if (ops.empty())
6727932d21fSUday Bondhugula return false;
6737932d21fSUday Bondhugula
6747932d21fSUday Bondhugula // Start the pattern driver.
6757932d21fSUday Bondhugula MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
6767932d21fSUday Bondhugula strict);
6777932d21fSUday Bondhugula return driver.simplifyLocally(ops);
6787932d21fSUday Bondhugula }
679