164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 264d52014SChris Lattner // 364d52014SChris Lattner // Copyright 2019 The MLIR Authors. 464d52014SChris Lattner // 564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License"); 664d52014SChris Lattner // you may not use this file except in compliance with the License. 764d52014SChris Lattner // You may obtain a copy of the License at 864d52014SChris Lattner // 964d52014SChris Lattner // http://www.apache.org/licenses/LICENSE-2.0 1064d52014SChris Lattner // 1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software 1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS, 1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1464d52014SChris Lattner // See the License for the specific language governing permissions and 1564d52014SChris Lattner // limitations under the License. 1664d52014SChris Lattner // ============================================================================= 1764d52014SChris Lattner // 1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily. 1964d52014SChris Lattner // 2064d52014SChris Lattner //===----------------------------------------------------------------------===// 2164d52014SChris Lattner 2264d52014SChris Lattner #include "mlir/IR/Builders.h" 2364d52014SChris Lattner #include "mlir/IR/BuiltinOps.h" 2464d52014SChris Lattner #include "mlir/StandardOps/StandardOps.h" 2564d52014SChris Lattner #include "mlir/Transforms/PatternMatch.h" 2664d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 2764d52014SChris Lattner using namespace mlir; 2864d52014SChris Lattner 2964d52014SChris Lattner namespace { 3064d52014SChris Lattner class WorklistRewriter; 3164d52014SChris Lattner 3264d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 3364d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 3464d52014SChris Lattner class GreedyPatternRewriteDriver { 3564d52014SChris Lattner public: 3664d52014SChris Lattner explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns) 3764d52014SChris Lattner : matcher(std::move(patterns)) { 3864d52014SChris Lattner worklist.reserve(64); 3964d52014SChris Lattner } 4064d52014SChris Lattner 4164d52014SChris Lattner void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter); 4264d52014SChris Lattner 4364d52014SChris Lattner void addToWorklist(Operation *op) { 4464d52014SChris Lattner worklistMap[op] = worklist.size(); 4564d52014SChris Lattner worklist.push_back(op); 4664d52014SChris Lattner } 4764d52014SChris Lattner 4864d52014SChris Lattner Operation *popFromWorklist() { 4964d52014SChris Lattner auto *op = worklist.back(); 5064d52014SChris Lattner worklist.pop_back(); 5164d52014SChris Lattner 5264d52014SChris Lattner // This operation is no longer in the worklist, keep worklistMap up to date. 5364d52014SChris Lattner if (op) 5464d52014SChris Lattner worklistMap.erase(op); 5564d52014SChris Lattner return op; 5664d52014SChris Lattner } 5764d52014SChris Lattner 5864d52014SChris Lattner /// If the specified operation is in the worklist, remove it. If not, this is 5964d52014SChris Lattner /// a no-op. 6064d52014SChris Lattner void removeFromWorklist(Operation *op) { 6164d52014SChris Lattner auto it = worklistMap.find(op); 6264d52014SChris Lattner if (it != worklistMap.end()) { 6364d52014SChris Lattner assert(worklist[it->second] == op && "malformed worklist data structure"); 6464d52014SChris Lattner worklist[it->second] = nullptr; 6564d52014SChris Lattner } 6664d52014SChris Lattner } 6764d52014SChris Lattner 6864d52014SChris Lattner private: 6964d52014SChris Lattner /// The low-level pattern matcher. 7064d52014SChris Lattner PatternMatcher matcher; 7164d52014SChris Lattner 7264d52014SChris Lattner /// The worklist for this transformation keeps track of the operations that 7364d52014SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 7464d52014SChris Lattner /// efficiently remove operations from the worklist when they are removed even 7564d52014SChris Lattner /// if they aren't the root of a pattern. 7664d52014SChris Lattner std::vector<Operation *> worklist; 7764d52014SChris Lattner DenseMap<Operation *, unsigned> worklistMap; 7864d52014SChris Lattner 7964d52014SChris Lattner /// As part of canonicalization, we move constants to the top of the entry 8064d52014SChris Lattner /// block of the current function and de-duplicate them. This keeps track of 8164d52014SChris Lattner /// constants we have done this for. 82*792d1c25SRiver Riddle DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants; 8364d52014SChris Lattner }; 8464d52014SChris Lattner }; // end anonymous namespace 8564d52014SChris Lattner 8664d52014SChris Lattner /// This is a listener object that updates our worklists and other data 8764d52014SChris Lattner /// structures in response to operations being added and removed. 8864d52014SChris Lattner namespace { 8964d52014SChris Lattner class WorklistRewriter : public PatternRewriter { 9064d52014SChris Lattner public: 9164d52014SChris Lattner WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context) 9264d52014SChris Lattner : PatternRewriter(context), driver(driver) {} 9364d52014SChris Lattner 9464d52014SChris Lattner virtual void setInsertionPoint(Operation *op) = 0; 9564d52014SChris Lattner 9664d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 9764d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 9864d52014SChris Lattner void notifyOperationRemoved(Operation *op) override { 9964d52014SChris Lattner driver.removeFromWorklist(op); 10064d52014SChris Lattner } 10164d52014SChris Lattner 10264d52014SChris Lattner GreedyPatternRewriteDriver &driver; 10364d52014SChris Lattner }; 10464d52014SChris Lattner 10564d52014SChris Lattner } // end anonymous namespace 10664d52014SChris Lattner 10764d52014SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, 10864d52014SChris Lattner WorklistRewriter &rewriter) { 10964d52014SChris Lattner // These are scratch vectors used in the constant folding loop below. 110*792d1c25SRiver Riddle SmallVector<Attribute, 8> operandConstants, resultConstants; 11164d52014SChris Lattner 11264d52014SChris Lattner while (!worklist.empty()) { 11364d52014SChris Lattner auto *op = popFromWorklist(); 11464d52014SChris Lattner 11564d52014SChris Lattner // Nulls get added to the worklist when operations are removed, ignore them. 11664d52014SChris Lattner if (op == nullptr) 11764d52014SChris Lattner continue; 11864d52014SChris Lattner 11964d52014SChris Lattner // If we have a constant op, unique it into the entry block. 12064d52014SChris Lattner if (auto constant = op->dyn_cast<ConstantOp>()) { 12164d52014SChris Lattner // If this constant is dead, remove it, being careful to keep 12264d52014SChris Lattner // uniquedConstants up to date. 12364d52014SChris Lattner if (constant->use_empty()) { 12464d52014SChris Lattner auto it = 12564d52014SChris Lattner uniquedConstants.find({constant->getValue(), constant->getType()}); 12664d52014SChris Lattner if (it != uniquedConstants.end() && it->second == op) 12764d52014SChris Lattner uniquedConstants.erase(it); 12864d52014SChris Lattner constant->erase(); 12964d52014SChris Lattner continue; 13064d52014SChris Lattner } 13164d52014SChris Lattner 13264d52014SChris Lattner // Check to see if we already have a constant with this type and value: 13364d52014SChris Lattner auto &entry = uniquedConstants[std::make_pair(constant->getValue(), 13464d52014SChris Lattner constant->getType())]; 13564d52014SChris Lattner if (entry) { 13664d52014SChris Lattner // If this constant is already our uniqued one, then leave it alone. 13764d52014SChris Lattner if (entry == op) 13864d52014SChris Lattner continue; 13964d52014SChris Lattner 14064d52014SChris Lattner // Otherwise replace this redundant constant with the uniqued one. We 14164d52014SChris Lattner // know this is safe because we move constants to the top of the 14264d52014SChris Lattner // function when they are uniqued, so we know they dominate all uses. 14364d52014SChris Lattner constant->replaceAllUsesWith(entry->getResult(0)); 14464d52014SChris Lattner constant->erase(); 14564d52014SChris Lattner continue; 14664d52014SChris Lattner } 14764d52014SChris Lattner 14864d52014SChris Lattner // If we have no entry, then we should unique this constant as the 14964d52014SChris Lattner // canonical version. To ensure safe dominance, move the operation to the 15064d52014SChris Lattner // top of the function. 15164d52014SChris Lattner entry = op; 15264d52014SChris Lattner 15364d52014SChris Lattner // TODO: If we make terminators into Operations then we could turn this 15464d52014SChris Lattner // into a nice Operation::moveBefore(Operation*) method. We just need the 15564d52014SChris Lattner // guarantee that a block is non-empty. 15664d52014SChris Lattner if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) { 15764d52014SChris Lattner auto &entryBB = cfgFunc->front(); 15864d52014SChris Lattner cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin()); 15964d52014SChris Lattner } else { 16064d52014SChris Lattner auto *mlFunc = cast<MLFunction>(currentFunction); 16164d52014SChris Lattner cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin()); 16264d52014SChris Lattner } 16364d52014SChris Lattner 16464d52014SChris Lattner continue; 16564d52014SChris Lattner } 16664d52014SChris Lattner 16764d52014SChris Lattner // If the operation has no side effects, and no users, then it is trivially 16864d52014SChris Lattner // dead - remove it. 16964d52014SChris Lattner if (op->hasNoSideEffect() && op->use_empty()) { 17064d52014SChris Lattner op->erase(); 17164d52014SChris Lattner continue; 17264d52014SChris Lattner } 17364d52014SChris Lattner 17464d52014SChris Lattner // Check to see if any operands to the instruction is constant and whether 17564d52014SChris Lattner // the operation knows how to constant fold itself. 17664d52014SChris Lattner operandConstants.clear(); 17764d52014SChris Lattner for (auto *operand : op->getOperands()) { 178*792d1c25SRiver Riddle Attribute operandCst; 17964d52014SChris Lattner if (auto *operandOp = operand->getDefiningOperation()) { 18064d52014SChris Lattner if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) 18164d52014SChris Lattner operandCst = operandConstantOp->getValue(); 18264d52014SChris Lattner } 18364d52014SChris Lattner operandConstants.push_back(operandCst); 18464d52014SChris Lattner } 18564d52014SChris Lattner 18664d52014SChris Lattner // If constant folding was successful, create the result constants, RAUW the 18764d52014SChris Lattner // operation and remove it. 18864d52014SChris Lattner resultConstants.clear(); 18964d52014SChris Lattner if (!op->constantFold(operandConstants, resultConstants)) { 19064d52014SChris Lattner rewriter.setInsertionPoint(op); 19164d52014SChris Lattner 19264d52014SChris Lattner for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 19364d52014SChris Lattner auto *res = op->getResult(i); 19464d52014SChris Lattner if (res->use_empty()) // ignore dead uses. 19564d52014SChris Lattner continue; 19664d52014SChris Lattner 19764d52014SChris Lattner // If we already have a canonicalized version of this constant, just 19864d52014SChris Lattner // reuse it. Otherwise create a new one. 19964d52014SChris Lattner SSAValue *cstValue; 20064d52014SChris Lattner auto it = uniquedConstants.find({resultConstants[i], res->getType()}); 20164d52014SChris Lattner if (it != uniquedConstants.end()) 20264d52014SChris Lattner cstValue = it->second->getResult(0); 20364d52014SChris Lattner else 20464d52014SChris Lattner cstValue = rewriter.create<ConstantOp>( 20564d52014SChris Lattner op->getLoc(), resultConstants[i], res->getType()); 20664d52014SChris Lattner res->replaceAllUsesWith(cstValue); 20764d52014SChris Lattner } 20864d52014SChris Lattner 20964d52014SChris Lattner assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); 21064d52014SChris Lattner op->erase(); 21164d52014SChris Lattner continue; 21264d52014SChris Lattner } 21364d52014SChris Lattner 21464d52014SChris Lattner // If this is an associative binary operation with a constant on the LHS, 21564d52014SChris Lattner // move it to the right side. 21664d52014SChris Lattner if (operandConstants.size() == 2 && operandConstants[0] && 21764d52014SChris Lattner !operandConstants[1]) { 21864d52014SChris Lattner auto *newLHS = op->getOperand(1); 21964d52014SChris Lattner op->setOperand(1, op->getOperand(0)); 22064d52014SChris Lattner op->setOperand(0, newLHS); 22164d52014SChris Lattner } 22264d52014SChris Lattner 22364d52014SChris Lattner // Check to see if we have any patterns that match this node. 22464d52014SChris Lattner auto match = matcher.findMatch(op); 22564d52014SChris Lattner if (!match.first) 22664d52014SChris Lattner continue; 22764d52014SChris Lattner 22864d52014SChris Lattner // Make sure that any new operations are inserted at this point. 22964d52014SChris Lattner rewriter.setInsertionPoint(op); 23064d52014SChris Lattner match.first->rewrite(op, std::move(match.second), rewriter); 23164d52014SChris Lattner } 23264d52014SChris Lattner 23364d52014SChris Lattner uniquedConstants.clear(); 23464d52014SChris Lattner } 23564d52014SChris Lattner 23664d52014SChris Lattner static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) { 23764d52014SChris Lattner class MLFuncRewriter : public WorklistRewriter { 23864d52014SChris Lattner public: 23964d52014SChris Lattner MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder) 24064d52014SChris Lattner : WorklistRewriter(driver, builder.getContext()), builder(builder) {} 24164d52014SChris Lattner 24264d52014SChris Lattner // Implement the hook for creating operations, and make sure that newly 24364d52014SChris Lattner // created ops are added to the worklist for processing. 24464d52014SChris Lattner Operation *createOperation(const OperationState &state) override { 24564d52014SChris Lattner auto *result = builder.createOperation(state); 24664d52014SChris Lattner driver.addToWorklist(result); 24764d52014SChris Lattner return result; 24864d52014SChris Lattner } 24964d52014SChris Lattner 25064d52014SChris Lattner // When the root of a pattern is about to be replaced, it can trigger 25164d52014SChris Lattner // simplifications to its users - make sure to add them to the worklist 25264d52014SChris Lattner // before the root is changed. 25364d52014SChris Lattner void notifyRootReplaced(Operation *op) override { 25464d52014SChris Lattner auto *opStmt = cast<OperationStmt>(op); 25564d52014SChris Lattner for (auto *result : opStmt->getResults()) 25664d52014SChris Lattner // TODO: Add a result->getUsers() iterator. 25764d52014SChris Lattner for (auto &user : result->getUses()) { 25864d52014SChris Lattner if (auto *op = dyn_cast<OperationStmt>(user.getOwner())) 25964d52014SChris Lattner driver.addToWorklist(op); 26064d52014SChris Lattner } 26164d52014SChris Lattner 26264d52014SChris Lattner // TODO: Walk the operand list dropping them as we go. If any of them 26364d52014SChris Lattner // drop to zero uses, then add them to the worklist to allow them to be 26464d52014SChris Lattner // deleted as dead. 26564d52014SChris Lattner } 26664d52014SChris Lattner 26764d52014SChris Lattner void setInsertionPoint(Operation *op) override { 26864d52014SChris Lattner // Any new operations should be added before this statement. 26964d52014SChris Lattner builder.setInsertionPoint(cast<OperationStmt>(op)); 27064d52014SChris Lattner } 27164d52014SChris Lattner 27264d52014SChris Lattner private: 27364d52014SChris Lattner MLFuncBuilder &builder; 27464d52014SChris Lattner }; 27564d52014SChris Lattner 27664d52014SChris Lattner GreedyPatternRewriteDriver driver(std::move(patterns)); 27764d52014SChris Lattner fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); }); 27864d52014SChris Lattner 27964d52014SChris Lattner MLFuncBuilder mlBuilder(fn); 28064d52014SChris Lattner MLFuncRewriter rewriter(driver, mlBuilder); 28164d52014SChris Lattner driver.simplifyFunction(fn, rewriter); 28264d52014SChris Lattner } 28364d52014SChris Lattner 28464d52014SChris Lattner static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) { 28564d52014SChris Lattner class CFGFuncRewriter : public WorklistRewriter { 28664d52014SChris Lattner public: 28764d52014SChris Lattner CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder) 28864d52014SChris Lattner : WorklistRewriter(driver, builder.getContext()), builder(builder) {} 28964d52014SChris Lattner 29064d52014SChris Lattner // Implement the hook for creating operations, and make sure that newly 29164d52014SChris Lattner // created ops are added to the worklist for processing. 29264d52014SChris Lattner Operation *createOperation(const OperationState &state) override { 29364d52014SChris Lattner auto *result = builder.createOperation(state); 29464d52014SChris Lattner driver.addToWorklist(result); 29564d52014SChris Lattner return result; 29664d52014SChris Lattner } 29764d52014SChris Lattner 29864d52014SChris Lattner // When the root of a pattern is about to be replaced, it can trigger 29964d52014SChris Lattner // simplifications to its users - make sure to add them to the worklist 30064d52014SChris Lattner // before the root is changed. 30164d52014SChris Lattner void notifyRootReplaced(Operation *op) override { 30264d52014SChris Lattner auto *opStmt = cast<OperationInst>(op); 30364d52014SChris Lattner for (auto *result : opStmt->getResults()) 30464d52014SChris Lattner // TODO: Add a result->getUsers() iterator. 30564d52014SChris Lattner for (auto &user : result->getUses()) { 30664d52014SChris Lattner if (auto *op = dyn_cast<OperationInst>(user.getOwner())) 30764d52014SChris Lattner driver.addToWorklist(op); 30864d52014SChris Lattner } 30964d52014SChris Lattner 31064d52014SChris Lattner // TODO: Walk the operand list dropping them as we go. If any of them 31164d52014SChris Lattner // drop to zero uses, then add them to the worklist to allow them to be 31264d52014SChris Lattner // deleted as dead. 31364d52014SChris Lattner } 31464d52014SChris Lattner 31564d52014SChris Lattner void setInsertionPoint(Operation *op) override { 31664d52014SChris Lattner // Any new operations should be added before this instruction. 31764d52014SChris Lattner builder.setInsertionPoint(cast<OperationInst>(op)); 31864d52014SChris Lattner } 31964d52014SChris Lattner 32064d52014SChris Lattner private: 32164d52014SChris Lattner CFGFuncBuilder &builder; 32264d52014SChris Lattner }; 32364d52014SChris Lattner 32464d52014SChris Lattner GreedyPatternRewriteDriver driver(std::move(patterns)); 32564d52014SChris Lattner for (auto &bb : *fn) 32664d52014SChris Lattner for (auto &op : bb) 32764d52014SChris Lattner driver.addToWorklist(&op); 32864d52014SChris Lattner 32964d52014SChris Lattner CFGFuncBuilder cfgBuilder(fn); 33064d52014SChris Lattner CFGFuncRewriter rewriter(driver, cfgBuilder); 33164d52014SChris Lattner driver.simplifyFunction(fn, rewriter); 33264d52014SChris Lattner } 33364d52014SChris Lattner 33464d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit 33564d52014SChris Lattner /// patterns in a greedy work-list driven manner. 33664d52014SChris Lattner /// 33764d52014SChris Lattner void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) { 33864d52014SChris Lattner if (auto *cfg = dyn_cast<CFGFunction>(fn)) { 33964d52014SChris Lattner processCFGFunction(cfg, std::move(patterns)); 34064d52014SChris Lattner } else { 34164d52014SChris Lattner processMLFunction(cast<MLFunction>(fn), std::move(patterns)); 34264d52014SChris Lattner } 34364d52014SChris Lattner } 344