164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 264d52014SChris Lattner // 364d52014SChris Lattner // Copyright 2019 The MLIR Authors. 464d52014SChris Lattner // 564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License"); 664d52014SChris Lattner // you may not use this file except in compliance with the License. 764d52014SChris Lattner // You may obtain a copy of the License at 864d52014SChris Lattner // 964d52014SChris Lattner // http://www.apache.org/licenses/LICENSE-2.0 1064d52014SChris Lattner // 1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software 1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS, 1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1464d52014SChris Lattner // See the License for the specific language governing permissions and 1564d52014SChris Lattner // limitations under the License. 1664d52014SChris Lattner // ============================================================================= 1764d52014SChris Lattner // 1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily. 1964d52014SChris Lattner // 2064d52014SChris Lattner //===----------------------------------------------------------------------===// 2164d52014SChris Lattner 2264d52014SChris Lattner #include "mlir/IR/Builders.h" 2364d52014SChris Lattner #include "mlir/IR/BuiltinOps.h" 247de0da95SChris Lattner #include "mlir/IR/PatternMatch.h" 2564d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 2664d52014SChris Lattner using namespace mlir; 2764d52014SChris Lattner 2864d52014SChris Lattner namespace { 2964d52014SChris Lattner 3064d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 3164d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 324bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter { 3364d52014SChris Lattner public: 344bd9f936SChris Lattner explicit GreedyPatternRewriteDriver(Function *fn, 354bd9f936SChris Lattner OwningRewritePatternList &&patterns) 364bd9f936SChris Lattner : PatternRewriter(fn->getContext()), matcher(std::move(patterns)), 374bd9f936SChris Lattner builder(fn) { 3864d52014SChris Lattner worklist.reserve(64); 394bd9f936SChris Lattner 404bd9f936SChris Lattner // Add all operations to the worklist. 414bd9f936SChris Lattner fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); }); 4264d52014SChris Lattner } 4364d52014SChris Lattner 444bd9f936SChris Lattner /// Perform the rewrites. 454bd9f936SChris Lattner void simplifyFunction(); 4664d52014SChris Lattner 475187cfcfSChris Lattner void addToWorklist(OperationInst *op) { 485c4f1fddSRiver Riddle // Check to see if the worklist already contains this op. 495c4f1fddSRiver Riddle if (worklistMap.count(op)) 505c4f1fddSRiver Riddle return; 515c4f1fddSRiver Riddle 5264d52014SChris Lattner worklistMap[op] = worklist.size(); 5364d52014SChris Lattner worklist.push_back(op); 5464d52014SChris Lattner } 5564d52014SChris Lattner 565187cfcfSChris Lattner OperationInst *popFromWorklist() { 5764d52014SChris Lattner auto *op = worklist.back(); 5864d52014SChris Lattner worklist.pop_back(); 5964d52014SChris Lattner 6064d52014SChris Lattner // This operation is no longer in the worklist, keep worklistMap up to date. 6164d52014SChris Lattner if (op) 6264d52014SChris Lattner worklistMap.erase(op); 6364d52014SChris Lattner return op; 6464d52014SChris Lattner } 6564d52014SChris Lattner 6664d52014SChris Lattner /// If the specified operation is in the worklist, remove it. If not, this is 6764d52014SChris Lattner /// a no-op. 685187cfcfSChris Lattner void removeFromWorklist(OperationInst *op) { 6964d52014SChris Lattner auto it = worklistMap.find(op); 7064d52014SChris Lattner if (it != worklistMap.end()) { 7164d52014SChris Lattner assert(worklist[it->second] == op && "malformed worklist data structure"); 7264d52014SChris Lattner worklist[it->second] = nullptr; 7364d52014SChris Lattner } 7464d52014SChris Lattner } 7564d52014SChris Lattner 764bd9f936SChris Lattner // These are hooks implemented for PatternRewriter. 774bd9f936SChris Lattner protected: 784bd9f936SChris Lattner // Implement the hook for creating operations, and make sure that newly 794bd9f936SChris Lattner // created ops are added to the worklist for processing. 804bd9f936SChris Lattner OperationInst *createOperation(const OperationState &state) override { 814bd9f936SChris Lattner auto *result = builder.createOperation(state); 824bd9f936SChris Lattner addToWorklist(result); 834bd9f936SChris Lattner return result; 844bd9f936SChris Lattner } 8564d52014SChris Lattner 8664d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 8764d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 885187cfcfSChris Lattner void notifyOperationRemoved(OperationInst *op) override { 894bd9f936SChris Lattner removeFromWorklist(op); 9064d52014SChris Lattner } 9164d52014SChris Lattner 92085b687fSChris Lattner // When the root of a pattern is about to be replaced, it can trigger 93085b687fSChris Lattner // simplifications to its users - make sure to add them to the worklist 94085b687fSChris Lattner // before the root is changed. 955187cfcfSChris Lattner void notifyRootReplaced(OperationInst *op) override { 96085b687fSChris Lattner for (auto *result : op->getResults()) 97085b687fSChris Lattner // TODO: Add a result->getUsers() iterator. 98085b687fSChris Lattner for (auto &user : result->getUses()) { 995187cfcfSChris Lattner if (auto *op = dyn_cast<OperationInst>(user.getOwner())) 1004bd9f936SChris Lattner addToWorklist(op); 101085b687fSChris Lattner } 102085b687fSChris Lattner 103085b687fSChris Lattner // TODO: Walk the operand list dropping them as we go. If any of them 104085b687fSChris Lattner // drop to zero uses, then add them to the worklist to allow them to be 105085b687fSChris Lattner // deleted as dead. 106085b687fSChris Lattner } 107085b687fSChris Lattner 1084bd9f936SChris Lattner private: 1094bd9f936SChris Lattner /// The low-level pattern matcher. 1104bd9f936SChris Lattner PatternMatcher matcher; 1114bd9f936SChris Lattner 1124bd9f936SChris Lattner /// This builder is used to create new operations. 1134bd9f936SChris Lattner FuncBuilder builder; 1144bd9f936SChris Lattner 1154bd9f936SChris Lattner /// The worklist for this transformation keeps track of the operations that 1164bd9f936SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 1174bd9f936SChris Lattner /// efficiently remove operations from the worklist when they are erased from 1184bd9f936SChris Lattner /// the function, even if they aren't the root of a pattern. 1194bd9f936SChris Lattner std::vector<OperationInst *> worklist; 1204bd9f936SChris Lattner DenseMap<OperationInst *, unsigned> worklistMap; 1214bd9f936SChris Lattner 1224bd9f936SChris Lattner /// As part of canonicalization, we move constants to the top of the entry 1234bd9f936SChris Lattner /// block of the current function and de-duplicate them. This keeps track of 1244bd9f936SChris Lattner /// constants we have done this for. 1254bd9f936SChris Lattner DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants; 12664d52014SChris Lattner }; 1274bd9f936SChris Lattner }; // end anonymous namespace 12864d52014SChris Lattner 1294bd9f936SChris Lattner /// Perform the rewrites. 1304bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() { 13164d52014SChris Lattner // These are scratch vectors used in the constant folding loop below. 132792d1c25SRiver Riddle SmallVector<Attribute, 8> operandConstants, resultConstants; 13364d52014SChris Lattner 13464d52014SChris Lattner while (!worklist.empty()) { 13564d52014SChris Lattner auto *op = popFromWorklist(); 13664d52014SChris Lattner 13764d52014SChris Lattner // Nulls get added to the worklist when operations are removed, ignore them. 13864d52014SChris Lattner if (op == nullptr) 13964d52014SChris Lattner continue; 14064d52014SChris Lattner 14164d52014SChris Lattner // If we have a constant op, unique it into the entry block. 14264d52014SChris Lattner if (auto constant = op->dyn_cast<ConstantOp>()) { 14364d52014SChris Lattner // If this constant is dead, remove it, being careful to keep 14464d52014SChris Lattner // uniquedConstants up to date. 14564d52014SChris Lattner if (constant->use_empty()) { 14664d52014SChris Lattner auto it = 14764d52014SChris Lattner uniquedConstants.find({constant->getValue(), constant->getType()}); 14864d52014SChris Lattner if (it != uniquedConstants.end() && it->second == op) 14964d52014SChris Lattner uniquedConstants.erase(it); 15064d52014SChris Lattner constant->erase(); 15164d52014SChris Lattner continue; 15264d52014SChris Lattner } 15364d52014SChris Lattner 15464d52014SChris Lattner // Check to see if we already have a constant with this type and value: 15564d52014SChris Lattner auto &entry = uniquedConstants[std::make_pair(constant->getValue(), 15664d52014SChris Lattner constant->getType())]; 15764d52014SChris Lattner if (entry) { 15864d52014SChris Lattner // If this constant is already our uniqued one, then leave it alone. 15964d52014SChris Lattner if (entry == op) 16064d52014SChris Lattner continue; 16164d52014SChris Lattner 16264d52014SChris Lattner // Otherwise replace this redundant constant with the uniqued one. We 16364d52014SChris Lattner // know this is safe because we move constants to the top of the 16464d52014SChris Lattner // function when they are uniqued, so we know they dominate all uses. 16564d52014SChris Lattner constant->replaceAllUsesWith(entry->getResult(0)); 16664d52014SChris Lattner constant->erase(); 16764d52014SChris Lattner continue; 16864d52014SChris Lattner } 16964d52014SChris Lattner 17064d52014SChris Lattner // If we have no entry, then we should unique this constant as the 17164d52014SChris Lattner // canonical version. To ensure safe dominance, move the operation to the 17264d52014SChris Lattner // top of the function. 17364d52014SChris Lattner entry = op; 1744bd9f936SChris Lattner auto &entryBB = builder.getInsertionBlock()->getFunction()->front(); 175471c9764SChris Lattner op->moveBefore(&entryBB, entryBB.begin()); 17664d52014SChris Lattner continue; 17764d52014SChris Lattner } 17864d52014SChris Lattner 17964d52014SChris Lattner // If the operation has no side effects, and no users, then it is trivially 18064d52014SChris Lattner // dead - remove it. 18164d52014SChris Lattner if (op->hasNoSideEffect() && op->use_empty()) { 18264d52014SChris Lattner op->erase(); 18364d52014SChris Lattner continue; 18464d52014SChris Lattner } 18564d52014SChris Lattner 18664d52014SChris Lattner // Check to see if any operands to the instruction is constant and whether 18764d52014SChris Lattner // the operation knows how to constant fold itself. 18864d52014SChris Lattner operandConstants.clear(); 18964d52014SChris Lattner for (auto *operand : op->getOperands()) { 190792d1c25SRiver Riddle Attribute operandCst; 1915187cfcfSChris Lattner if (auto *operandOp = operand->getDefiningInst()) { 19264d52014SChris Lattner if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) 19364d52014SChris Lattner operandCst = operandConstantOp->getValue(); 19464d52014SChris Lattner } 19564d52014SChris Lattner operandConstants.push_back(operandCst); 19664d52014SChris Lattner } 19764d52014SChris Lattner 19864d52014SChris Lattner // If constant folding was successful, create the result constants, RAUW the 19964d52014SChris Lattner // operation and remove it. 20064d52014SChris Lattner resultConstants.clear(); 20164d52014SChris Lattner if (!op->constantFold(operandConstants, resultConstants)) { 2024bd9f936SChris Lattner builder.setInsertionPoint(op); 20364d52014SChris Lattner 20464d52014SChris Lattner for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 20564d52014SChris Lattner auto *res = op->getResult(i); 20664d52014SChris Lattner if (res->use_empty()) // ignore dead uses. 20764d52014SChris Lattner continue; 20864d52014SChris Lattner 20964d52014SChris Lattner // If we already have a canonicalized version of this constant, just 21064d52014SChris Lattner // reuse it. Otherwise create a new one. 2113f190312SChris Lattner Value *cstValue; 21264d52014SChris Lattner auto it = uniquedConstants.find({resultConstants[i], res->getType()}); 21364d52014SChris Lattner if (it != uniquedConstants.end()) 21464d52014SChris Lattner cstValue = it->second->getResult(0); 21564d52014SChris Lattner else 216*61ec6c09SLei Zhang cstValue = create<ConstantOp>(op->getLoc(), res->getType(), 217*61ec6c09SLei Zhang resultConstants[i]); 218967d9341SChris Lattner 219967d9341SChris Lattner // Add all the users of the result to the worklist so we make sure to 220967d9341SChris Lattner // revisit them. 221967d9341SChris Lattner // 222967d9341SChris Lattner // TODO: Add a result->getUsers() iterator. 223085b687fSChris Lattner for (auto &operand : op->getResult(i)->getUses()) { 2245187cfcfSChris Lattner if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) 225967d9341SChris Lattner addToWorklist(op); 226967d9341SChris Lattner } 227967d9341SChris Lattner 22864d52014SChris Lattner res->replaceAllUsesWith(cstValue); 22964d52014SChris Lattner } 23064d52014SChris Lattner 23164d52014SChris Lattner assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); 23264d52014SChris Lattner op->erase(); 23364d52014SChris Lattner continue; 23464d52014SChris Lattner } 23564d52014SChris Lattner 236967d9341SChris Lattner // If this is a commutative binary operation with a constant on the left 237967d9341SChris Lattner // side move it to the right side. 23864d52014SChris Lattner if (operandConstants.size() == 2 && operandConstants[0] && 239967d9341SChris Lattner !operandConstants[1] && op->isCommutative()) { 24064d52014SChris Lattner auto *newLHS = op->getOperand(1); 24164d52014SChris Lattner op->setOperand(1, op->getOperand(0)); 24264d52014SChris Lattner op->setOperand(0, newLHS); 24364d52014SChris Lattner } 24464d52014SChris Lattner 24564d52014SChris Lattner // Check to see if we have any patterns that match this node. 24664d52014SChris Lattner auto match = matcher.findMatch(op); 24764d52014SChris Lattner if (!match.first) 24864d52014SChris Lattner continue; 24964d52014SChris Lattner 25064d52014SChris Lattner // Make sure that any new operations are inserted at this point. 2514bd9f936SChris Lattner builder.setInsertionPoint(op); 2523f2530cdSChris Lattner // We know that any pattern that matched is RewritePattern because we 2533f2530cdSChris Lattner // initialized the matcher with RewritePatterns. 2543f2530cdSChris Lattner auto *rewritePattern = static_cast<RewritePattern *>(match.first); 2554bd9f936SChris Lattner rewritePattern->rewrite(op, std::move(match.second), *this); 25664d52014SChris Lattner } 25764d52014SChris Lattner 25864d52014SChris Lattner uniquedConstants.clear(); 25964d52014SChris Lattner } 26064d52014SChris Lattner 26164d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit 26264d52014SChris Lattner /// patterns in a greedy work-list driven manner. 26364d52014SChris Lattner /// 2643f2530cdSChris Lattner void mlir::applyPatternsGreedily(Function *fn, 2653f2530cdSChris Lattner OwningRewritePatternList &&patterns) { 2664bd9f936SChris Lattner GreedyPatternRewriteDriver driver(fn, std::move(patterns)); 2674bd9f936SChris Lattner driver.simplifyFunction(); 26864d52014SChris Lattner } 269