164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 264d52014SChris Lattner // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 664d52014SChris Lattner // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 864d52014SChris Lattner // 9a60fdd2bSLorenzo Chelini // This file implements mlir::applyPatternsAndFoldGreedily. 1064d52014SChris Lattner // 1164d52014SChris Lattner //===----------------------------------------------------------------------===// 1264d52014SChris Lattner 13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14af371f9fSRiver Riddle #include "mlir/IR/Matchers.h" 15eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h" 16b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h" 171982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h" 18fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h" 1964d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 205c757087SFeng Liu #include "llvm/Support/CommandLine.h" 215c757087SFeng Liu #include "llvm/Support/Debug.h" 225652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h" 235c757087SFeng Liu #include "llvm/Support/raw_ostream.h" 244e40c832SLei Zhang 2564d52014SChris Lattner using namespace mlir; 2664d52014SChris Lattner 275652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter" 285c757087SFeng Liu 2904b5274eSUday Bondhugula //===----------------------------------------------------------------------===// 3004b5274eSUday Bondhugula // GreedyPatternRewriteDriver 3104b5274eSUday Bondhugula //===----------------------------------------------------------------------===// 3204b5274eSUday Bondhugula 3364d52014SChris Lattner namespace { 3464d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 3564d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 364bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter { 3764d52014SChris Lattner public: 382566a72aSRiver Riddle explicit GreedyPatternRewriteDriver(MLIRContext *ctx, 39648f34a2SChris Lattner const FrozenRewritePatternSet &patterns, 40b7144ab7SRiver Riddle const GreedyRewriteConfig &config); 413e98fbf4SRiver Riddle 42b7144ab7SRiver Riddle /// Simplify the operations within the given regions. 4364716b2cSChris Lattner bool simplify(MutableArrayRef<Region> regions); 4464d52014SChris Lattner 45b7144ab7SRiver Riddle /// Add the given operation to the worklist. 46b7144ab7SRiver Riddle void addToWorklist(Operation *op); 475c4f1fddSRiver Riddle 48b7144ab7SRiver Riddle /// Pop the next operation from the worklist. 49b7144ab7SRiver Riddle Operation *popFromWorklist(); 5064d52014SChris Lattner 51b7144ab7SRiver Riddle /// If the specified operation is in the worklist, remove it. 52b7144ab7SRiver Riddle void removeFromWorklist(Operation *op); 5364d52014SChris Lattner 544bd9f936SChris Lattner protected: 55851a8516SRiver Riddle // Implement the hook for inserting operations, and make sure that newly 56851a8516SRiver Riddle // inserted ops are added to the worklist for processing. 57b7144ab7SRiver Riddle void notifyOperationInserted(Operation *op) override; 5864d52014SChris Lattner 597932d21fSUday Bondhugula // Look over the provided operands for any defining operations that should 607932d21fSUday Bondhugula // be re-added to the worklist. This function should be called when an 617932d21fSUday Bondhugula // operation is modified or removed, as it may trigger further 627932d21fSUday Bondhugula // simplifications. 637932d21fSUday Bondhugula template <typename Operands> 64b7144ab7SRiver Riddle void addToWorklist(Operands &&operands); 657932d21fSUday Bondhugula 6664d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 6764d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 68b7144ab7SRiver Riddle void notifyOperationRemoved(Operation *op) override; 6964d52014SChris Lattner 70085b687fSChris Lattner // When the root of a pattern is about to be replaced, it can trigger 71085b687fSChris Lattner // simplifications to its users - make sure to add them to the worklist 72085b687fSChris Lattner // before the root is changed. 73b7144ab7SRiver Riddle void notifyRootReplaced(Operation *op) override; 74085b687fSChris Lattner 755652ecc3SRiver Riddle /// PatternRewriter hook for erasing a dead operation. 765652ecc3SRiver Riddle void eraseOp(Operation *op) override; 775652ecc3SRiver Riddle 785652ecc3SRiver Riddle /// PatternRewriter hook for notifying match failure reasons. 795652ecc3SRiver Riddle LogicalResult 80ea64828aSRiver Riddle notifyMatchFailure(Location loc, 815652ecc3SRiver Riddle function_ref<void(Diagnostic &)> reasonCallback) override; 825652ecc3SRiver Riddle 833e98fbf4SRiver Riddle /// The low-level pattern applicator. 843e98fbf4SRiver Riddle PatternApplicator matcher; 854bd9f936SChris Lattner 864bd9f936SChris Lattner /// The worklist for this transformation keeps track of the operations that 874bd9f936SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 88e7a2ef21SRiver Riddle /// efficiently remove operations from the worklist when they are erased, even 89e7a2ef21SRiver Riddle /// if they aren't the root of a pattern. 9099b87c97SRiver Riddle std::vector<Operation *> worklist; 9199b87c97SRiver Riddle DenseMap<Operation *, unsigned> worklistMap; 9260a29837SRiver Riddle 9360a29837SRiver Riddle /// Non-pattern based folder for operations. 9460a29837SRiver Riddle OperationFolder folder; 95648f34a2SChris Lattner 967932d21fSUday Bondhugula private: 9764716b2cSChris Lattner /// Configuration information for how to simplify. 9864716b2cSChris Lattner GreedyRewriteConfig config; 995652ecc3SRiver Riddle 1005652ecc3SRiver Riddle #ifndef NDEBUG 1015652ecc3SRiver Riddle /// A logger used to emit information during the application process. 1025652ecc3SRiver Riddle llvm::ScopedPrinter logger{llvm::dbgs()}; 1035652ecc3SRiver Riddle #endif 10464d52014SChris Lattner }; 105be0a7e9fSMehdi Amini } // namespace 10664d52014SChris Lattner 107b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( 108b7144ab7SRiver Riddle MLIRContext *ctx, const FrozenRewritePatternSet &patterns, 109b7144ab7SRiver Riddle const GreedyRewriteConfig &config) 110b7144ab7SRiver Riddle : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { 111b7144ab7SRiver Riddle worklist.reserve(64); 112b7144ab7SRiver Riddle 113b7144ab7SRiver Riddle // Apply a simple cost model based solely on pattern benefit. 114b7144ab7SRiver Riddle matcher.applyDefaultCostModel(); 115b7144ab7SRiver Riddle } 116b7144ab7SRiver Riddle 11764716b2cSChris Lattner bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) { 1185652ecc3SRiver Riddle #ifndef NDEBUG 1195652ecc3SRiver Riddle const char *logLineComment = 1205652ecc3SRiver Riddle "//===-------------------------------------------===//\n"; 1215652ecc3SRiver Riddle 1225652ecc3SRiver Riddle /// A utility function to log a process result for the given reason. 1235652ecc3SRiver Riddle auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) { 1245652ecc3SRiver Riddle logger.unindent(); 1255652ecc3SRiver Riddle logger.startLine() << "} -> " << result; 1265652ecc3SRiver Riddle if (!msg.isTriviallyEmpty()) 1275652ecc3SRiver Riddle logger.getOStream() << " : " << msg; 1285652ecc3SRiver Riddle logger.getOStream() << "\n"; 1295652ecc3SRiver Riddle }; 1305652ecc3SRiver Riddle auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) { 1315652ecc3SRiver Riddle logResult(result, msg); 1325652ecc3SRiver Riddle logger.startLine() << logLineComment; 1335652ecc3SRiver Riddle }; 1345652ecc3SRiver Riddle #endif 1355652ecc3SRiver Riddle 1367814b559Srkayaith auto insertKnownConstant = [&](Operation *op) { 1377814b559Srkayaith // Check for existing constants when populating the worklist. This avoids 1387814b559Srkayaith // accidentally reversing the constant order during processing. 1397814b559Srkayaith Attribute constValue; 1407814b559Srkayaith if (matchPattern(op, m_Constant(&constValue))) 1417814b559Srkayaith if (!folder.insertKnownConstant(op, constValue)) 1427814b559Srkayaith return true; 1437814b559Srkayaith return false; 1447814b559Srkayaith }; 1457814b559Srkayaith 1465c757087SFeng Liu bool changed = false; 14764716b2cSChris Lattner unsigned iteration = 0; 1485c757087SFeng Liu do { 149648f34a2SChris Lattner worklist.clear(); 150648f34a2SChris Lattner worklistMap.clear(); 151648f34a2SChris Lattner 15264716b2cSChris Lattner if (!config.useTopDownTraversal) { 15364716b2cSChris Lattner // Add operations to the worklist in postorder. 154af371f9fSRiver Riddle for (auto ®ion : regions) { 1557814b559Srkayaith region.walk([&](Operation *op) { 1567814b559Srkayaith if (!insertKnownConstant(op)) 157af371f9fSRiver Riddle addToWorklist(op); 158af371f9fSRiver Riddle }); 159af371f9fSRiver Riddle } 16064716b2cSChris Lattner } else { 161648f34a2SChris Lattner // Add all nested operations to the worklist in preorder. 162aa568e08SRiver Riddle for (auto ®ion : regions) { 1637814b559Srkayaith region.walk<WalkOrder::PreOrder>([&](Operation *op) { 164aa568e08SRiver Riddle if (!insertKnownConstant(op)) { 1657814b559Srkayaith worklist.push_back(op); 166aa568e08SRiver Riddle return WalkResult::advance(); 167aa568e08SRiver Riddle } 168aa568e08SRiver Riddle return WalkResult::skip(); 1697814b559Srkayaith }); 170aa568e08SRiver Riddle } 171648f34a2SChris Lattner 172648f34a2SChris Lattner // Reverse the list so our pop-back loop processes them in-order. 173648f34a2SChris Lattner std::reverse(worklist.begin(), worklist.end()); 174648f34a2SChris Lattner // Remember the reverse index. 175648f34a2SChris Lattner for (size_t i = 0, e = worklist.size(); i != e; ++i) 176648f34a2SChris Lattner worklistMap[worklist[i]] = i; 177648f34a2SChris Lattner } 1784e40c832SLei Zhang 1794e40c832SLei Zhang // These are scratch vectors used in the folding loop below. 180e62a6956SRiver Riddle SmallVector<Value, 8> originalOperands, resultValues; 18164d52014SChris Lattner 1825c757087SFeng Liu changed = false; 18364d52014SChris Lattner while (!worklist.empty()) { 18464d52014SChris Lattner auto *op = popFromWorklist(); 18564d52014SChris Lattner 1865c757087SFeng Liu // Nulls get added to the worklist when operations are removed, ignore 1875c757087SFeng Liu // them. 18864d52014SChris Lattner if (op == nullptr) 18964d52014SChris Lattner continue; 19064d52014SChris Lattner 1915652ecc3SRiver Riddle LLVM_DEBUG({ 1925652ecc3SRiver Riddle logger.getOStream() << "\n"; 1935652ecc3SRiver Riddle logger.startLine() << logLineComment; 1945652ecc3SRiver Riddle logger.startLine() << "Processing operation : '" << op->getName() 1955652ecc3SRiver Riddle << "'(" << op << ") {\n"; 1965652ecc3SRiver Riddle logger.indent(); 1975652ecc3SRiver Riddle 1985652ecc3SRiver Riddle // If the operation has no regions, just print it here. 1995652ecc3SRiver Riddle if (op->getNumRegions() == 0) { 2005652ecc3SRiver Riddle op->print( 2015652ecc3SRiver Riddle logger.startLine(), 2025652ecc3SRiver Riddle OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 2035652ecc3SRiver Riddle logger.getOStream() << "\n\n"; 2045652ecc3SRiver Riddle } 2055652ecc3SRiver Riddle }); 2065652ecc3SRiver Riddle 2070ddba0bdSRiver Riddle // If the operation is trivially dead - remove it. 2080ddba0bdSRiver Riddle if (isOpTriviallyDead(op)) { 2096a501e3dSAndy Ly notifyOperationRemoved(op); 21064d52014SChris Lattner op->erase(); 211f875e55bSUday Bondhugula changed = true; 2125652ecc3SRiver Riddle 2135652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); 21464d52014SChris Lattner continue; 21564d52014SChris Lattner } 21664d52014SChris Lattner 2174e40c832SLei Zhang // Collects all the operands and result uses of the given `op` into work 2186a501e3dSAndy Ly // list. Also remove `op` and nested ops from worklist. 2191982afb1SRiver Riddle originalOperands.assign(op->operand_begin(), op->operand_end()); 2206a501e3dSAndy Ly auto preReplaceAction = [&](Operation *op) { 221a8866258SRiver Riddle // Add the operands to the worklist for visitation. 2221982afb1SRiver Riddle addToWorklist(originalOperands); 2231982afb1SRiver Riddle 2244e40c832SLei Zhang // Add all the users of the result to the worklist so we make sure 2254e40c832SLei Zhang // to revisit them. 22635807bc4SRiver Riddle for (auto result : op->getResults()) 227cc673894SUday Bondhugula for (auto *userOp : result.getUsers()) 228cc673894SUday Bondhugula addToWorklist(userOp); 2296a501e3dSAndy Ly 2306a501e3dSAndy Ly notifyOperationRemoved(op); 2314e40c832SLei Zhang }; 23264d52014SChris Lattner 233648f34a2SChris Lattner // Add the given operation to the worklist. 234648f34a2SChris Lattner auto collectOps = [this](Operation *op) { addToWorklist(op); }; 235648f34a2SChris Lattner 2361982afb1SRiver Riddle // Try to fold this op. 237cbcb12fdSUday Bondhugula bool inPlaceUpdate; 238cbcb12fdSUday Bondhugula if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, 239cbcb12fdSUday Bondhugula &inPlaceUpdate)))) { 2405652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "operation was folded")); 2415652ecc3SRiver Riddle 242f875e55bSUday Bondhugula changed = true; 243cbcb12fdSUday Bondhugula if (!inPlaceUpdate) 244934b6d12SChris Lattner continue; 24564d52014SChris Lattner } 24664d52014SChris Lattner 24732052c84SRiver Riddle // Try to match one of the patterns. The rewriter is automatically 248648f34a2SChris Lattner // notified of any necessary changes, so there is nothing else to do 249648f34a2SChris Lattner // here. 2505652ecc3SRiver Riddle #ifndef NDEBUG 2515652ecc3SRiver Riddle auto canApply = [&](const Pattern &pattern) { 2525652ecc3SRiver Riddle LLVM_DEBUG({ 2535652ecc3SRiver Riddle logger.getOStream() << "\n"; 2545652ecc3SRiver Riddle logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" 2555652ecc3SRiver Riddle << op->getName() << " -> ("; 2565652ecc3SRiver Riddle llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); 2575652ecc3SRiver Riddle logger.getOStream() << ")' {\n"; 2585652ecc3SRiver Riddle logger.indent(); 2595652ecc3SRiver Riddle }); 2605652ecc3SRiver Riddle return true; 2615652ecc3SRiver Riddle }; 2625652ecc3SRiver Riddle auto onFailure = [&](const Pattern &pattern) { 2635652ecc3SRiver Riddle LLVM_DEBUG(logResult("failure", "pattern failed to match")); 2645652ecc3SRiver Riddle }; 2655652ecc3SRiver Riddle auto onSuccess = [&](const Pattern &pattern) { 2665652ecc3SRiver Riddle LLVM_DEBUG(logResult("success", "pattern applied successfully")); 2675652ecc3SRiver Riddle return success(); 2685652ecc3SRiver Riddle }; 2695652ecc3SRiver Riddle 2705652ecc3SRiver Riddle LogicalResult matchResult = 2715652ecc3SRiver Riddle matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); 2725652ecc3SRiver Riddle if (succeeded(matchResult)) 2735652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("success", "pattern matched")); 2745652ecc3SRiver Riddle else 2755652ecc3SRiver Riddle LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); 2765652ecc3SRiver Riddle #else 2775652ecc3SRiver Riddle LogicalResult matchResult = matcher.matchAndRewrite(op, *this); 2785652ecc3SRiver Riddle #endif 2795652ecc3SRiver Riddle changed |= succeeded(matchResult); 28064d52014SChris Lattner } 281a32f0dcbSRiver Riddle 282648f34a2SChris Lattner // After applying patterns, make sure that the CFG of each of the regions 283648f34a2SChris Lattner // is kept up to date. 28464716b2cSChris Lattner if (config.enableRegionSimplification) 285d75a611aSRiver Riddle changed |= succeeded(simplifyRegions(*this, regions)); 286519663beSFrederik Gossen } while (changed && 287673e9828SFrederik Gossen (iteration++ < config.maxIterations || 288519663beSFrederik Gossen config.maxIterations == GreedyRewriteConfig::kNoIterationLimit)); 289648f34a2SChris Lattner 2905c757087SFeng Liu // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 2915c757087SFeng Liu return !changed; 29264d52014SChris Lattner } 29364d52014SChris Lattner 294b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { 295b7144ab7SRiver Riddle // Check to see if the worklist already contains this op. 296b7144ab7SRiver Riddle if (worklistMap.count(op)) 297b7144ab7SRiver Riddle return; 298b7144ab7SRiver Riddle 299b7144ab7SRiver Riddle worklistMap[op] = worklist.size(); 300b7144ab7SRiver Riddle worklist.push_back(op); 301b7144ab7SRiver Riddle } 302b7144ab7SRiver Riddle 303b7144ab7SRiver Riddle Operation *GreedyPatternRewriteDriver::popFromWorklist() { 304b7144ab7SRiver Riddle auto *op = worklist.back(); 305b7144ab7SRiver Riddle worklist.pop_back(); 306b7144ab7SRiver Riddle 307b7144ab7SRiver Riddle // This operation is no longer in the worklist, keep worklistMap up to date. 308b7144ab7SRiver Riddle if (op) 309b7144ab7SRiver Riddle worklistMap.erase(op); 310b7144ab7SRiver Riddle return op; 311b7144ab7SRiver Riddle } 312b7144ab7SRiver Riddle 313b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) { 314b7144ab7SRiver Riddle auto it = worklistMap.find(op); 315b7144ab7SRiver Riddle if (it != worklistMap.end()) { 316b7144ab7SRiver Riddle assert(worklist[it->second] == op && "malformed worklist data structure"); 317b7144ab7SRiver Riddle worklist[it->second] = nullptr; 318b7144ab7SRiver Riddle worklistMap.erase(it); 319b7144ab7SRiver Riddle } 320b7144ab7SRiver Riddle } 321b7144ab7SRiver Riddle 322b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { 3235652ecc3SRiver Riddle LLVM_DEBUG({ 3245652ecc3SRiver Riddle logger.startLine() << "** Insert : '" << op->getName() << "'(" << op 3255652ecc3SRiver Riddle << ")\n"; 3265652ecc3SRiver Riddle }); 327b7144ab7SRiver Riddle addToWorklist(op); 328b7144ab7SRiver Riddle } 329b7144ab7SRiver Riddle 330b7144ab7SRiver Riddle template <typename Operands> 331b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) { 332b7144ab7SRiver Riddle for (Value operand : operands) { 333b7144ab7SRiver Riddle // If the use count of this operand is now < 2, we re-add the defining 334b7144ab7SRiver Riddle // operation to the worklist. 335b7144ab7SRiver Riddle // TODO: This is based on the fact that zero use operations 336b7144ab7SRiver Riddle // may be deleted, and that single use values often have more 337b7144ab7SRiver Riddle // canonicalization opportunities. 338b7144ab7SRiver Riddle if (!operand || (!operand.use_empty() && !operand.hasOneUse())) 339b7144ab7SRiver Riddle continue; 340b7144ab7SRiver Riddle if (auto *defOp = operand.getDefiningOp()) 341b7144ab7SRiver Riddle addToWorklist(defOp); 342b7144ab7SRiver Riddle } 343b7144ab7SRiver Riddle } 344b7144ab7SRiver Riddle 345b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { 346b7144ab7SRiver Riddle addToWorklist(op->getOperands()); 347b7144ab7SRiver Riddle op->walk([this](Operation *operation) { 348b7144ab7SRiver Riddle removeFromWorklist(operation); 349b7144ab7SRiver Riddle folder.notifyRemoval(operation); 350b7144ab7SRiver Riddle }); 351b7144ab7SRiver Riddle } 352b7144ab7SRiver Riddle 353b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) { 3545652ecc3SRiver Riddle LLVM_DEBUG({ 3555652ecc3SRiver Riddle logger.startLine() << "** Replace : '" << op->getName() << "'(" << op 3565652ecc3SRiver Riddle << ")\n"; 3575652ecc3SRiver Riddle }); 358b7144ab7SRiver Riddle for (auto result : op->getResults()) 359b7144ab7SRiver Riddle for (auto *user : result.getUsers()) 360b7144ab7SRiver Riddle addToWorklist(user); 361b7144ab7SRiver Riddle } 362b7144ab7SRiver Riddle 3635652ecc3SRiver Riddle void GreedyPatternRewriteDriver::eraseOp(Operation *op) { 3645652ecc3SRiver Riddle LLVM_DEBUG({ 3655652ecc3SRiver Riddle logger.startLine() << "** Erase : '" << op->getName() << "'(" << op 3665652ecc3SRiver Riddle << ")\n"; 3675652ecc3SRiver Riddle }); 3685652ecc3SRiver Riddle PatternRewriter::eraseOp(op); 3695652ecc3SRiver Riddle } 3705652ecc3SRiver Riddle 3715652ecc3SRiver Riddle LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( 372ea64828aSRiver Riddle Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 3735652ecc3SRiver Riddle LLVM_DEBUG({ 374ea64828aSRiver Riddle Diagnostic diag(loc, DiagnosticSeverity::Remark); 3755652ecc3SRiver Riddle reasonCallback(diag); 3765652ecc3SRiver Riddle logger.startLine() << "** Failure : " << diag.str() << "\n"; 3775652ecc3SRiver Riddle }); 3785652ecc3SRiver Riddle return failure(); 3795652ecc3SRiver Riddle } 3805652ecc3SRiver Riddle 381e7a2ef21SRiver Riddle /// Rewrite the regions of the specified operation, which must be isolated from 382e7a2ef21SRiver Riddle /// above, by repeatedly applying the highest benefit patterns in a greedy 3833e98fbf4SRiver Riddle /// work-list driven manner. Return success if no more patterns can be matched 3843e98fbf4SRiver Riddle /// in the result operation regions. Note: This does not apply patterns to the 3853e98fbf4SRiver Riddle /// top-level operation itself. 38664d52014SChris Lattner /// 3873e98fbf4SRiver Riddle LogicalResult 3883e98fbf4SRiver Riddle mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, 3890b20413eSUday Bondhugula const FrozenRewritePatternSet &patterns, 39064716b2cSChris Lattner GreedyRewriteConfig config) { 3916b1cc3c6SRiver Riddle if (regions.empty()) 3923e98fbf4SRiver Riddle return success(); 3936b1cc3c6SRiver Riddle 394e7a2ef21SRiver Riddle // The top-level operation must be known to be isolated from above to 395e7a2ef21SRiver Riddle // prevent performing canonicalizations on operations defined at or above 396e7a2ef21SRiver Riddle // the region containing 'op'. 3976b1cc3c6SRiver Riddle auto regionIsIsolated = [](Region ®ion) { 398fe7c0d90SRiver Riddle return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>(); 3996b1cc3c6SRiver Riddle }; 4006b1cc3c6SRiver Riddle (void)regionIsIsolated; 4016b1cc3c6SRiver Riddle assert(llvm::all_of(regions, regionIsIsolated) && 4026b1cc3c6SRiver Riddle "patterns can only be applied to operations IsolatedFromAbove"); 403e7a2ef21SRiver Riddle 4046b1cc3c6SRiver Riddle // Start the pattern driver. 40564716b2cSChris Lattner GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); 40664716b2cSChris Lattner bool converged = driver.simplify(regions); 4075c757087SFeng Liu LLVM_DEBUG(if (!converged) { 408e7a2ef21SRiver Riddle llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " 40964716b2cSChris Lattner << config.maxIterations << " times\n"; 4105c757087SFeng Liu }); 4113e98fbf4SRiver Riddle return success(converged); 41264d52014SChris Lattner } 41304b5274eSUday Bondhugula 41404b5274eSUday Bondhugula //===----------------------------------------------------------------------===// 41504b5274eSUday Bondhugula // OpPatternRewriteDriver 41604b5274eSUday Bondhugula //===----------------------------------------------------------------------===// 41704b5274eSUday Bondhugula 41804b5274eSUday Bondhugula namespace { 41904b5274eSUday Bondhugula /// This is a simple driver for the PatternMatcher to apply patterns and perform 42004b5274eSUday Bondhugula /// folding on a single op. It repeatedly applies locally optimal patterns. 42104b5274eSUday Bondhugula class OpPatternRewriteDriver : public PatternRewriter { 42204b5274eSUday Bondhugula public: 42304b5274eSUday Bondhugula explicit OpPatternRewriteDriver(MLIRContext *ctx, 42479d7f618SChris Lattner const FrozenRewritePatternSet &patterns) 4253e98fbf4SRiver Riddle : PatternRewriter(ctx), matcher(patterns), folder(ctx) { 4263e98fbf4SRiver Riddle // Apply a simple cost model based solely on pattern benefit. 4273e98fbf4SRiver Riddle matcher.applyDefaultCostModel(); 4283e98fbf4SRiver Riddle } 42904b5274eSUday Bondhugula 4303e98fbf4SRiver Riddle LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased); 43104b5274eSUday Bondhugula 43204b5274eSUday Bondhugula // These are hooks implemented for PatternRewriter. 43304b5274eSUday Bondhugula protected: 43404b5274eSUday Bondhugula /// If an operation is about to be removed, mark it so that we can let clients 43504b5274eSUday Bondhugula /// know. 43604b5274eSUday Bondhugula void notifyOperationRemoved(Operation *op) override { 43704b5274eSUday Bondhugula opErasedViaPatternRewrites = true; 43804b5274eSUday Bondhugula } 43904b5274eSUday Bondhugula 44004b5274eSUday Bondhugula // When a root is going to be replaced, its removal will be notified as well. 44104b5274eSUday Bondhugula // So there is nothing to do here. 44204b5274eSUday Bondhugula void notifyRootReplaced(Operation *op) override {} 44304b5274eSUday Bondhugula 44404b5274eSUday Bondhugula private: 4453e98fbf4SRiver Riddle /// The low-level pattern applicator. 4463e98fbf4SRiver Riddle PatternApplicator matcher; 44704b5274eSUday Bondhugula 44804b5274eSUday Bondhugula /// Non-pattern based folder for operations. 44904b5274eSUday Bondhugula OperationFolder folder; 45004b5274eSUday Bondhugula 45104b5274eSUday Bondhugula /// Set to true if the operation has been erased via pattern rewrites. 45204b5274eSUday Bondhugula bool opErasedViaPatternRewrites = false; 45304b5274eSUday Bondhugula }; 45404b5274eSUday Bondhugula 455be0a7e9fSMehdi Amini } // namespace 45604b5274eSUday Bondhugula 4577932d21fSUday Bondhugula /// Performs the rewrites and folding only on `op`. The simplification 4587932d21fSUday Bondhugula /// converges if the op is erased as a result of being folded, replaced, or 4597932d21fSUday Bondhugula /// becoming dead, or no more changes happen in an iteration. Returns success if 4607932d21fSUday Bondhugula /// the rewrite converges in `maxIterations`. `erased` is set to true if `op` 4617932d21fSUday Bondhugula /// gets erased. 4623e98fbf4SRiver Riddle LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, 4633e98fbf4SRiver Riddle int maxIterations, 46404b5274eSUday Bondhugula bool &erased) { 46504b5274eSUday Bondhugula bool changed = false; 46604b5274eSUday Bondhugula erased = false; 46704b5274eSUday Bondhugula opErasedViaPatternRewrites = false; 4687932d21fSUday Bondhugula int iterations = 0; 46904b5274eSUday Bondhugula // Iterate until convergence or until maxIterations. Deletion of the op as 47004b5274eSUday Bondhugula // a result of being dead or folded is convergence. 47104b5274eSUday Bondhugula do { 472ff87c4d3SChristian Sigg changed = false; 473ff87c4d3SChristian Sigg 47404b5274eSUday Bondhugula // If the operation is trivially dead - remove it. 47504b5274eSUday Bondhugula if (isOpTriviallyDead(op)) { 47604b5274eSUday Bondhugula op->erase(); 47704b5274eSUday Bondhugula erased = true; 4783e98fbf4SRiver Riddle return success(); 47904b5274eSUday Bondhugula } 48004b5274eSUday Bondhugula 48104b5274eSUday Bondhugula // Try to fold this op. 48204b5274eSUday Bondhugula bool inPlaceUpdate; 48304b5274eSUday Bondhugula if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr, 48404b5274eSUday Bondhugula /*preReplaceAction=*/nullptr, 48504b5274eSUday Bondhugula &inPlaceUpdate))) { 48604b5274eSUday Bondhugula changed = true; 48704b5274eSUday Bondhugula if (!inPlaceUpdate) { 48804b5274eSUday Bondhugula erased = true; 4893e98fbf4SRiver Riddle return success(); 49004b5274eSUday Bondhugula } 49104b5274eSUday Bondhugula } 49204b5274eSUday Bondhugula 49304b5274eSUday Bondhugula // Try to match one of the patterns. The rewriter is automatically 49404b5274eSUday Bondhugula // notified of any necessary changes, so there is nothing else to do here. 4953e98fbf4SRiver Riddle changed |= succeeded(matcher.matchAndRewrite(op, *this)); 49604b5274eSUday Bondhugula if ((erased = opErasedViaPatternRewrites)) 4973e98fbf4SRiver Riddle return success(); 498519663beSFrederik Gossen } while (changed && 499519663beSFrederik Gossen (++iterations < maxIterations || 500519663beSFrederik Gossen maxIterations == GreedyRewriteConfig::kNoIterationLimit)); 50104b5274eSUday Bondhugula 50204b5274eSUday Bondhugula // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 5033e98fbf4SRiver Riddle return failure(changed); 50404b5274eSUday Bondhugula } 50504b5274eSUday Bondhugula 5067932d21fSUday Bondhugula //===----------------------------------------------------------------------===// 5077932d21fSUday Bondhugula // MultiOpPatternRewriteDriver 5087932d21fSUday Bondhugula //===----------------------------------------------------------------------===// 5097932d21fSUday Bondhugula 5107932d21fSUday Bondhugula namespace { 5117932d21fSUday Bondhugula 5127932d21fSUday Bondhugula /// This is a specialized GreedyPatternRewriteDriver to apply patterns and 5137932d21fSUday Bondhugula /// perform folding for a supplied set of ops. It repeatedly simplifies while 5147932d21fSUday Bondhugula /// restricting the rewrites to only the provided set of ops or optionally 5157932d21fSUday Bondhugula /// to those directly affected by it (result users or operand providers). 5167932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { 5177932d21fSUday Bondhugula public: 5187932d21fSUday Bondhugula explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, 5197932d21fSUday Bondhugula const FrozenRewritePatternSet &patterns, 5207932d21fSUday Bondhugula bool strict) 5217932d21fSUday Bondhugula : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), 5227932d21fSUday Bondhugula strictMode(strict) {} 5237932d21fSUday Bondhugula 5247932d21fSUday Bondhugula bool simplifyLocally(ArrayRef<Operation *> op); 5257932d21fSUday Bondhugula 5267932d21fSUday Bondhugula private: 5277932d21fSUday Bondhugula // Look over the provided operands for any defining operations that should 5287932d21fSUday Bondhugula // be re-added to the worklist. This function should be called when an 5297932d21fSUday Bondhugula // operation is modified or removed, as it may trigger further 5307932d21fSUday Bondhugula // simplifications. If `strict` is set to true, only ops in 5317932d21fSUday Bondhugula // `strictModeFilteredOps` are considered. 5327932d21fSUday Bondhugula template <typename Operands> 5337932d21fSUday Bondhugula void addOperandsToWorklist(Operands &&operands) { 5347932d21fSUday Bondhugula for (Value operand : operands) { 5357932d21fSUday Bondhugula if (auto *defOp = operand.getDefiningOp()) { 5367932d21fSUday Bondhugula if (!strictMode || strictModeFilteredOps.contains(defOp)) 5377932d21fSUday Bondhugula addToWorklist(defOp); 5387932d21fSUday Bondhugula } 5397932d21fSUday Bondhugula } 5407932d21fSUday Bondhugula } 5417932d21fSUday Bondhugula 542*2aeffc6dSChia-hung Duan void notifyOperationInserted(Operation *op) override { 543*2aeffc6dSChia-hung Duan GreedyPatternRewriteDriver::notifyOperationInserted(op); 544*2aeffc6dSChia-hung Duan if (strictMode) 545*2aeffc6dSChia-hung Duan strictModeFilteredOps.insert(op); 546*2aeffc6dSChia-hung Duan } 547*2aeffc6dSChia-hung Duan 5487932d21fSUday Bondhugula void notifyOperationRemoved(Operation *op) override { 5497932d21fSUday Bondhugula GreedyPatternRewriteDriver::notifyOperationRemoved(op); 5507932d21fSUday Bondhugula if (strictMode) 5517932d21fSUday Bondhugula strictModeFilteredOps.erase(op); 5527932d21fSUday Bondhugula } 5537932d21fSUday Bondhugula 554*2aeffc6dSChia-hung Duan void notifyRootReplaced(Operation *op) override { 555*2aeffc6dSChia-hung Duan for (auto result : op->getResults()) { 556*2aeffc6dSChia-hung Duan for (auto *user : result.getUsers()) { 557*2aeffc6dSChia-hung Duan if (!strictMode || strictModeFilteredOps.contains(user)) 558*2aeffc6dSChia-hung Duan addToWorklist(user); 559*2aeffc6dSChia-hung Duan } 560*2aeffc6dSChia-hung Duan } 561*2aeffc6dSChia-hung Duan } 562*2aeffc6dSChia-hung Duan 5637932d21fSUday Bondhugula /// If `strictMode` is true, any pre-existing ops outside of 5647932d21fSUday Bondhugula /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. 5657932d21fSUday Bondhugula /// If `strictMode` is false, operations that use results of (or supply 5667932d21fSUday Bondhugula /// operands to) any rewritten ops stemming from the simplification of the 5677932d21fSUday Bondhugula /// provided ops are in turn simplified; any other ops still remain untouched 5687932d21fSUday Bondhugula /// (i.e., regardless of `strictMode`). 5697932d21fSUday Bondhugula bool strictMode = false; 5707932d21fSUday Bondhugula 5717932d21fSUday Bondhugula /// The list of ops we are restricting our rewrites to if `strictMode` is on. 5727932d21fSUday Bondhugula /// These include the supplied set of ops as well as new ops created while 5737932d21fSUday Bondhugula /// rewriting those ops. This set is not maintained when strictMode is off. 5747932d21fSUday Bondhugula llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; 5757932d21fSUday Bondhugula }; 5767932d21fSUday Bondhugula 577be0a7e9fSMehdi Amini } // namespace 5787932d21fSUday Bondhugula 5797932d21fSUday Bondhugula /// Performs the specified rewrites on `ops` while also trying to fold these ops 5807932d21fSUday Bondhugula /// as well as any other ops that were in turn created due to these rewrite 5817932d21fSUday Bondhugula /// patterns. Any pre-existing ops outside of `ops` remain completely 5827932d21fSUday Bondhugula /// unmodified if `strictMode` is true. If `strictMode` is false, other 5837932d21fSUday Bondhugula /// operations that use results of rewritten ops or supply operands to such ops 5847932d21fSUday Bondhugula /// are in turn simplified; any other ops still remain unmodified (i.e., 5857932d21fSUday Bondhugula /// regardless of `strictMode`). Note that ops in `ops` could be erased as a 5867932d21fSUday Bondhugula /// result of folding, becoming dead, or via pattern rewrites. Returns true if 5877932d21fSUday Bondhugula /// at all any changes happened. 5887932d21fSUday Bondhugula // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op 5897932d21fSUday Bondhugula // or GreedyPatternRewriteDriver::simplify, this method just iterates until 5907932d21fSUday Bondhugula // the worklist is empty. As our objective is to keep simplification "local", 5917932d21fSUday Bondhugula // there is no strong rationale to re-add all operations into the worklist and 5927932d21fSUday Bondhugula // rerun until an iteration changes nothing. If more widereaching simplification 5937932d21fSUday Bondhugula // is desired, GreedyPatternRewriteDriver should be used. 5947932d21fSUday Bondhugula bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) { 5957932d21fSUday Bondhugula if (strictMode) { 5967932d21fSUday Bondhugula strictModeFilteredOps.clear(); 5977932d21fSUday Bondhugula strictModeFilteredOps.insert(ops.begin(), ops.end()); 5987932d21fSUday Bondhugula } 5997932d21fSUday Bondhugula 6007932d21fSUday Bondhugula bool changed = false; 6017932d21fSUday Bondhugula worklist.clear(); 6027932d21fSUday Bondhugula worklistMap.clear(); 6037932d21fSUday Bondhugula for (Operation *op : ops) 6047932d21fSUday Bondhugula addToWorklist(op); 6057932d21fSUday Bondhugula 6067932d21fSUday Bondhugula // These are scratch vectors used in the folding loop below. 6077932d21fSUday Bondhugula SmallVector<Value, 8> originalOperands, resultValues; 6087932d21fSUday Bondhugula while (!worklist.empty()) { 6097932d21fSUday Bondhugula Operation *op = popFromWorklist(); 610*2aeffc6dSChia-hung Duan assert((!strictMode || strictModeFilteredOps.contains(op)) && 611*2aeffc6dSChia-hung Duan "unexpected op was inserted under strict mode"); 6127932d21fSUday Bondhugula 6137932d21fSUday Bondhugula // Nulls get added to the worklist when operations are removed, ignore 6147932d21fSUday Bondhugula // them. 6157932d21fSUday Bondhugula if (op == nullptr) 6167932d21fSUday Bondhugula continue; 6177932d21fSUday Bondhugula 6187932d21fSUday Bondhugula // If the operation is trivially dead - remove it. 6197932d21fSUday Bondhugula if (isOpTriviallyDead(op)) { 6207932d21fSUday Bondhugula notifyOperationRemoved(op); 6217932d21fSUday Bondhugula op->erase(); 6227932d21fSUday Bondhugula changed = true; 6237932d21fSUday Bondhugula continue; 6247932d21fSUday Bondhugula } 6257932d21fSUday Bondhugula 6267932d21fSUday Bondhugula // Collects all the operands and result uses of the given `op` into work 6277932d21fSUday Bondhugula // list. Also remove `op` and nested ops from worklist. 6287932d21fSUday Bondhugula originalOperands.assign(op->operand_begin(), op->operand_end()); 6297932d21fSUday Bondhugula auto preReplaceAction = [&](Operation *op) { 6307932d21fSUday Bondhugula // Add the operands to the worklist for visitation. 6317932d21fSUday Bondhugula addOperandsToWorklist(originalOperands); 6327932d21fSUday Bondhugula 6337932d21fSUday Bondhugula // Add all the users of the result to the worklist so we make sure 6347932d21fSUday Bondhugula // to revisit them. 6357932d21fSUday Bondhugula for (Value result : op->getResults()) 6367932d21fSUday Bondhugula for (Operation *userOp : result.getUsers()) { 6377932d21fSUday Bondhugula if (!strictMode || strictModeFilteredOps.contains(userOp)) 6387932d21fSUday Bondhugula addToWorklist(userOp); 6397932d21fSUday Bondhugula } 6407932d21fSUday Bondhugula notifyOperationRemoved(op); 6417932d21fSUday Bondhugula }; 6427932d21fSUday Bondhugula 6437932d21fSUday Bondhugula // Add the given operation generated by the folder to the worklist. 6447932d21fSUday Bondhugula auto processGeneratedConstants = [this](Operation *op) { 6457932d21fSUday Bondhugula // Newly created ops are also simplified -- these are also "local". 6467932d21fSUday Bondhugula addToWorklist(op); 6477932d21fSUday Bondhugula // When strict mode is off, we don't need to maintain 6487932d21fSUday Bondhugula // strictModeFilteredOps. 6497932d21fSUday Bondhugula if (strictMode) 6507932d21fSUday Bondhugula strictModeFilteredOps.insert(op); 6517932d21fSUday Bondhugula }; 6527932d21fSUday Bondhugula 6537932d21fSUday Bondhugula // Try to fold this op. 6547932d21fSUday Bondhugula bool inPlaceUpdate; 6557932d21fSUday Bondhugula if (succeeded(folder.tryToFold(op, processGeneratedConstants, 6567932d21fSUday Bondhugula preReplaceAction, &inPlaceUpdate))) { 6577932d21fSUday Bondhugula changed = true; 6587932d21fSUday Bondhugula if (!inPlaceUpdate) { 6597932d21fSUday Bondhugula // Op has been erased. 6607932d21fSUday Bondhugula continue; 6617932d21fSUday Bondhugula } 6627932d21fSUday Bondhugula } 6637932d21fSUday Bondhugula 6647932d21fSUday Bondhugula // Try to match one of the patterns. The rewriter is automatically 6657932d21fSUday Bondhugula // notified of any necessary changes, so there is nothing else to do 6667932d21fSUday Bondhugula // here. 6677932d21fSUday Bondhugula changed |= succeeded(matcher.matchAndRewrite(op, *this)); 6687932d21fSUday Bondhugula } 6697932d21fSUday Bondhugula 6707932d21fSUday Bondhugula return changed; 6717932d21fSUday Bondhugula } 6727932d21fSUday Bondhugula 67304b5274eSUday Bondhugula /// Rewrites only `op` using the supplied canonicalization patterns and 67404b5274eSUday Bondhugula /// folding. `erased` is set to true if the op is erased as a result of being 67504b5274eSUday Bondhugula /// folded, replaced, or dead. 6763e98fbf4SRiver Riddle LogicalResult mlir::applyOpPatternsAndFold( 67779d7f618SChris Lattner Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { 67804b5274eSUday Bondhugula // Start the pattern driver. 67964716b2cSChris Lattner GreedyRewriteConfig config; 68004b5274eSUday Bondhugula OpPatternRewriteDriver driver(op->getContext(), patterns); 68104b5274eSUday Bondhugula bool opErased; 6823e98fbf4SRiver Riddle LogicalResult converged = 68364716b2cSChris Lattner driver.simplifyLocally(op, config.maxIterations, opErased); 68404b5274eSUday Bondhugula if (erased) 68504b5274eSUday Bondhugula *erased = opErased; 6863e98fbf4SRiver Riddle LLVM_DEBUG(if (failed(converged)) { 68704b5274eSUday Bondhugula llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " 68864716b2cSChris Lattner << config.maxIterations << " times"; 68904b5274eSUday Bondhugula }); 69004b5274eSUday Bondhugula return converged; 69104b5274eSUday Bondhugula } 6927932d21fSUday Bondhugula 6937932d21fSUday Bondhugula bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, 6947932d21fSUday Bondhugula const FrozenRewritePatternSet &patterns, 6957932d21fSUday Bondhugula bool strict) { 6967932d21fSUday Bondhugula if (ops.empty()) 6977932d21fSUday Bondhugula return false; 6987932d21fSUday Bondhugula 6997932d21fSUday Bondhugula // Start the pattern driver. 7007932d21fSUday Bondhugula MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, 7017932d21fSUday Bondhugula strict); 7027932d21fSUday Bondhugula return driver.simplifyLocally(ops); 7037932d21fSUday Bondhugula } 704