//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements mlir::applyPatternsGreedily.
1964d52014SChris Lattner // 2064d52014SChris Lattner //===----------------------------------------------------------------------===// 2164d52014SChris Lattner 2264d52014SChris Lattner #include "mlir/IR/Builders.h" 237de0da95SChris Lattner #include "mlir/IR/PatternMatch.h" 24f37651c7SRiver Riddle #include "mlir/StandardOps/Ops.h" 25*4e40c832SLei Zhang #include "mlir/Transforms/ConstantFoldUtils.h" 2664d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 27*4e40c832SLei Zhang 2864d52014SChris Lattner using namespace mlir; 2964d52014SChris Lattner 3064d52014SChris Lattner namespace { 3164d52014SChris Lattner 3264d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 3364d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 344bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter { 3564d52014SChris Lattner public: 3646ade282SChris Lattner explicit GreedyPatternRewriteDriver(Function &fn, 374bd9f936SChris Lattner OwningRewritePatternList &&patterns) 3846ade282SChris Lattner : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), 3946ade282SChris Lattner builder(&fn) { 4064d52014SChris Lattner worklist.reserve(64); 414bd9f936SChris Lattner 424bd9f936SChris Lattner // Add all operations to the worklist. 4399b87c97SRiver Riddle fn.walk([&](Operation *op) { addToWorklist(op); }); 4464d52014SChris Lattner } 4564d52014SChris Lattner 464bd9f936SChris Lattner /// Perform the rewrites. 474bd9f936SChris Lattner void simplifyFunction(); 4864d52014SChris Lattner 4999b87c97SRiver Riddle void addToWorklist(Operation *op) { 505c4f1fddSRiver Riddle // Check to see if the worklist already contains this op. 
515c4f1fddSRiver Riddle if (worklistMap.count(op)) 525c4f1fddSRiver Riddle return; 535c4f1fddSRiver Riddle 5464d52014SChris Lattner worklistMap[op] = worklist.size(); 5564d52014SChris Lattner worklist.push_back(op); 5664d52014SChris Lattner } 5764d52014SChris Lattner 5899b87c97SRiver Riddle Operation *popFromWorklist() { 5964d52014SChris Lattner auto *op = worklist.back(); 6064d52014SChris Lattner worklist.pop_back(); 6164d52014SChris Lattner 6264d52014SChris Lattner // This operation is no longer in the worklist, keep worklistMap up to date. 6364d52014SChris Lattner if (op) 6464d52014SChris Lattner worklistMap.erase(op); 6564d52014SChris Lattner return op; 6664d52014SChris Lattner } 6764d52014SChris Lattner 6864d52014SChris Lattner /// If the specified operation is in the worklist, remove it. If not, this is 6964d52014SChris Lattner /// a no-op. 7099b87c97SRiver Riddle void removeFromWorklist(Operation *op) { 7164d52014SChris Lattner auto it = worklistMap.find(op); 7264d52014SChris Lattner if (it != worklistMap.end()) { 7364d52014SChris Lattner assert(worklist[it->second] == op && "malformed worklist data structure"); 7464d52014SChris Lattner worklist[it->second] = nullptr; 7564d52014SChris Lattner } 7664d52014SChris Lattner } 7764d52014SChris Lattner 784bd9f936SChris Lattner // These are hooks implemented for PatternRewriter. 794bd9f936SChris Lattner protected: 804bd9f936SChris Lattner // Implement the hook for creating operations, and make sure that newly 814bd9f936SChris Lattner // created ops are added to the worklist for processing. 
8299b87c97SRiver Riddle Operation *createOperation(const OperationState &state) override { 834bd9f936SChris Lattner auto *result = builder.createOperation(state); 844bd9f936SChris Lattner addToWorklist(result); 854bd9f936SChris Lattner return result; 864bd9f936SChris Lattner } 8764d52014SChris Lattner 8864d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 8964d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 9099b87c97SRiver Riddle void notifyOperationRemoved(Operation *op) override { 91a8866258SRiver Riddle addToWorklist(op->getOperands()); 924bd9f936SChris Lattner removeFromWorklist(op); 9364d52014SChris Lattner } 9464d52014SChris Lattner 95085b687fSChris Lattner // When the root of a pattern is about to be replaced, it can trigger 96085b687fSChris Lattner // simplifications to its users - make sure to add them to the worklist 97085b687fSChris Lattner // before the root is changed. 9899b87c97SRiver Riddle void notifyRootReplaced(Operation *op) override { 99085b687fSChris Lattner for (auto *result : op->getResults()) 100085b687fSChris Lattner // TODO: Add a result->getUsers() iterator. 101b499277fSRiver Riddle for (auto &user : result->getUses()) 102b499277fSRiver Riddle addToWorklist(user.getOwner()); 103085b687fSChris Lattner } 104085b687fSChris Lattner 1054bd9f936SChris Lattner private: 10699b87c97SRiver Riddle // Look over the provided operands for any defining operations that should 107a8866258SRiver Riddle // be re-added to the worklist. This function should be called when an 108a8866258SRiver Riddle // operation is modified or removed, as it may trigger further 109a8866258SRiver Riddle // simplifications. 
110a8866258SRiver Riddle template <typename Operands> void addToWorklist(Operands &&operands) { 111a8866258SRiver Riddle for (Value *operand : operands) { 112a8866258SRiver Riddle // If the use count of this operand is now < 2, we re-add the defining 11399b87c97SRiver Riddle // operation to the worklist. 11499b87c97SRiver Riddle // TODO(riverriddle) This is based on the fact that zero use operations 115a8866258SRiver Riddle // may be deleted, and that single use values often have more 116a8866258SRiver Riddle // canonicalization opportunities. 117a8866258SRiver Riddle if (!operand->use_empty() && 118a8866258SRiver Riddle std::next(operand->use_begin()) != operand->use_end()) 119a8866258SRiver Riddle continue; 120f9d91531SRiver Riddle if (auto *defInst = operand->getDefiningOp()) 121a8866258SRiver Riddle addToWorklist(defInst); 122a8866258SRiver Riddle } 123a8866258SRiver Riddle } 124a8866258SRiver Riddle 1254bd9f936SChris Lattner /// The low-level pattern matcher. 1265de726f4SRiver Riddle RewritePatternMatcher matcher; 1274bd9f936SChris Lattner 1284bd9f936SChris Lattner /// This builder is used to create new operations. 1294bd9f936SChris Lattner FuncBuilder builder; 1304bd9f936SChris Lattner 1314bd9f936SChris Lattner /// The worklist for this transformation keeps track of the operations that 1324bd9f936SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 1334bd9f936SChris Lattner /// efficiently remove operations from the worklist when they are erased from 1344bd9f936SChris Lattner /// the function, even if they aren't the root of a pattern. 13599b87c97SRiver Riddle std::vector<Operation *> worklist; 13699b87c97SRiver Riddle DenseMap<Operation *, unsigned> worklistMap; 13764d52014SChris Lattner }; 1384bd9f936SChris Lattner }; // end anonymous namespace 13964d52014SChris Lattner 1404bd9f936SChris Lattner /// Perform the rewrites. 
1414bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() { 142*4e40c832SLei Zhang ConstantFoldHelper helper(builder.getFunction()); 143*4e40c832SLei Zhang 144*4e40c832SLei Zhang // These are scratch vectors used in the folding loop below. 145934b6d12SChris Lattner SmallVector<Value *, 8> originalOperands, resultValues; 14664d52014SChris Lattner 14764d52014SChris Lattner while (!worklist.empty()) { 14864d52014SChris Lattner auto *op = popFromWorklist(); 14964d52014SChris Lattner 15064d52014SChris Lattner // Nulls get added to the worklist when operations are removed, ignore them. 15164d52014SChris Lattner if (op == nullptr) 15264d52014SChris Lattner continue; 15364d52014SChris Lattner 15464d52014SChris Lattner // If the operation has no side effects, and no users, then it is trivially 15564d52014SChris Lattner // dead - remove it. 15664d52014SChris Lattner if (op->hasNoSideEffect() && op->use_empty()) { 157*4e40c832SLei Zhang // Be careful to update bookkeeping in ConstantHelper to keep consistency 158*4e40c832SLei Zhang // if this is a constant op. 159*4e40c832SLei Zhang if (op->isa<ConstantOp>()) 160*4e40c832SLei Zhang helper.notifyRemoval(op); 16164d52014SChris Lattner op->erase(); 16264d52014SChris Lattner continue; 16364d52014SChris Lattner } 16464d52014SChris Lattner 165*4e40c832SLei Zhang // Collects all the operands and result uses of the given `op` into work 166*4e40c832SLei Zhang // list. 167*4e40c832SLei Zhang auto collectOperandsAndUses = [this](Operation *op) { 168a8866258SRiver Riddle // Add the operands to the worklist for visitation. 169a8866258SRiver Riddle addToWorklist(op->getOperands()); 170*4e40c832SLei Zhang // Add all the users of the result to the worklist so we make sure 171*4e40c832SLei Zhang // to revisit them. 172967d9341SChris Lattner // 173967d9341SChris Lattner // TODO: Add a result->getUsers() iterator. 
174*4e40c832SLei Zhang for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 175b499277fSRiver Riddle for (auto &operand : op->getResult(i)->getUses()) 176b499277fSRiver Riddle addToWorklist(operand.getOwner()); 17764d52014SChris Lattner } 178*4e40c832SLei Zhang }; 17964d52014SChris Lattner 180*4e40c832SLei Zhang // Try to constant fold this op. 181*4e40c832SLei Zhang if (helper.tryToConstantFold(op, collectOperandsAndUses)) { 18264d52014SChris Lattner assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); 18364d52014SChris Lattner op->erase(); 18464d52014SChris Lattner continue; 18564d52014SChris Lattner } 18664d52014SChris Lattner 187934b6d12SChris Lattner // Otherwise see if we can use the generic folder API to simplify the 188934b6d12SChris Lattner // operation. 189934b6d12SChris Lattner originalOperands.assign(op->operand_begin(), op->operand_end()); 190934b6d12SChris Lattner resultValues.clear(); 1915e1f1d2cSRiver Riddle if (succeeded(op->fold(resultValues))) { 192934b6d12SChris Lattner // If the result was an in-place simplification (e.g. max(x,x,y) -> 193934b6d12SChris Lattner // max(x,y)) then add the original operands to the worklist so we can make 194934b6d12SChris Lattner // sure to revisit them. 195934b6d12SChris Lattner if (resultValues.empty()) { 196a8866258SRiver Riddle // Add the operands back to the worklist as there may be more 197a8866258SRiver Riddle // canonicalization opportunities now. 198a8866258SRiver Riddle addToWorklist(originalOperands); 199934b6d12SChris Lattner } else { 200934b6d12SChris Lattner // Otherwise, the operation is simplified away completely. 201934b6d12SChris Lattner assert(resultValues.size() == op->getNumResults()); 202934b6d12SChris Lattner 203a8866258SRiver Riddle // Notify that we are replacing this operation. 204a8866258SRiver Riddle notifyRootReplaced(op); 205a8866258SRiver Riddle 206a8866258SRiver Riddle // Replace the result values and erase the operation. 
207934b6d12SChris Lattner for (unsigned i = 0, e = resultValues.size(); i != e; ++i) { 208934b6d12SChris Lattner auto *res = op->getResult(i); 209a8866258SRiver Riddle if (!res->use_empty()) 210934b6d12SChris Lattner res->replaceAllUsesWith(resultValues[i]); 211934b6d12SChris Lattner } 212934b6d12SChris Lattner 213a8866258SRiver Riddle notifyOperationRemoved(op); 214934b6d12SChris Lattner op->erase(); 21599fee0b1SRiver Riddle } 216934b6d12SChris Lattner continue; 21764d52014SChris Lattner } 21864d52014SChris Lattner 21964d52014SChris Lattner // Make sure that any new operations are inserted at this point. 2204bd9f936SChris Lattner builder.setInsertionPoint(op); 2215de726f4SRiver Riddle 2225de726f4SRiver Riddle // Try to match one of the canonicalization patterns. The rewriter is 2235de726f4SRiver Riddle // automatically notified of any necessary changes, so there is nothing else 2245de726f4SRiver Riddle // to do here. 2255de726f4SRiver Riddle matcher.matchAndRewrite(op); 22664d52014SChris Lattner } 22764d52014SChris Lattner } 22864d52014SChris Lattner 22964d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit 23064d52014SChris Lattner /// patterns in a greedy work-list driven manner. 23164d52014SChris Lattner /// 23246ade282SChris Lattner void mlir::applyPatternsGreedily(Function &fn, 2333f2530cdSChris Lattner OwningRewritePatternList &&patterns) { 2344bd9f936SChris Lattner GreedyPatternRewriteDriver driver(fn, std::move(patterns)); 2354bd9f936SChris Lattner driver.simplifyFunction(); 23664d52014SChris Lattner } 237