//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements mlir::applyPatternsGreedily, a worklist-driven driver
// that greedily applies constant uniquing, folding and rewrite patterns.
1964d52014SChris Lattner // 2064d52014SChris Lattner //===----------------------------------------------------------------------===// 2164d52014SChris Lattner 2264d52014SChris Lattner #include "mlir/IR/Builders.h" 23da0ebe06SRiver Riddle #include "mlir/IR/Matchers.h" 247de0da95SChris Lattner #include "mlir/IR/PatternMatch.h" 25f37651c7SRiver Riddle #include "mlir/StandardOps/Ops.h" 2664d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 2764d52014SChris Lattner using namespace mlir; 2864d52014SChris Lattner 2964d52014SChris Lattner namespace { 3064d52014SChris Lattner 3164d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 3264d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 334bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter { 3464d52014SChris Lattner public: 35*46ade282SChris Lattner explicit GreedyPatternRewriteDriver(Function &fn, 364bd9f936SChris Lattner OwningRewritePatternList &&patterns) 37*46ade282SChris Lattner : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), 38*46ade282SChris Lattner builder(&fn) { 3964d52014SChris Lattner worklist.reserve(64); 404bd9f936SChris Lattner 414bd9f936SChris Lattner // Add all operations to the worklist. 42*46ade282SChris Lattner fn.walk([&](Instruction *inst) { addToWorklist(inst); }); 4364d52014SChris Lattner } 4464d52014SChris Lattner 454bd9f936SChris Lattner /// Perform the rewrites. 464bd9f936SChris Lattner void simplifyFunction(); 4764d52014SChris Lattner 48b499277fSRiver Riddle void addToWorklist(Instruction *op) { 495c4f1fddSRiver Riddle // Check to see if the worklist already contains this op. 
505c4f1fddSRiver Riddle if (worklistMap.count(op)) 515c4f1fddSRiver Riddle return; 525c4f1fddSRiver Riddle 5364d52014SChris Lattner worklistMap[op] = worklist.size(); 5464d52014SChris Lattner worklist.push_back(op); 5564d52014SChris Lattner } 5664d52014SChris Lattner 57b499277fSRiver Riddle Instruction *popFromWorklist() { 5864d52014SChris Lattner auto *op = worklist.back(); 5964d52014SChris Lattner worklist.pop_back(); 6064d52014SChris Lattner 6164d52014SChris Lattner // This operation is no longer in the worklist, keep worklistMap up to date. 6264d52014SChris Lattner if (op) 6364d52014SChris Lattner worklistMap.erase(op); 6464d52014SChris Lattner return op; 6564d52014SChris Lattner } 6664d52014SChris Lattner 6764d52014SChris Lattner /// If the specified operation is in the worklist, remove it. If not, this is 6864d52014SChris Lattner /// a no-op. 69b499277fSRiver Riddle void removeFromWorklist(Instruction *op) { 7064d52014SChris Lattner auto it = worklistMap.find(op); 7164d52014SChris Lattner if (it != worklistMap.end()) { 7264d52014SChris Lattner assert(worklist[it->second] == op && "malformed worklist data structure"); 7364d52014SChris Lattner worklist[it->second] = nullptr; 7464d52014SChris Lattner } 7564d52014SChris Lattner } 7664d52014SChris Lattner 774bd9f936SChris Lattner // These are hooks implemented for PatternRewriter. 784bd9f936SChris Lattner protected: 794bd9f936SChris Lattner // Implement the hook for creating operations, and make sure that newly 804bd9f936SChris Lattner // created ops are added to the worklist for processing. 
81b499277fSRiver Riddle Instruction *createOperation(const OperationState &state) override { 824bd9f936SChris Lattner auto *result = builder.createOperation(state); 834bd9f936SChris Lattner addToWorklist(result); 844bd9f936SChris Lattner return result; 854bd9f936SChris Lattner } 8664d52014SChris Lattner 8764d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 8864d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 89b499277fSRiver Riddle void notifyOperationRemoved(Instruction *op) override { 90a8866258SRiver Riddle addToWorklist(op->getOperands()); 914bd9f936SChris Lattner removeFromWorklist(op); 9264d52014SChris Lattner } 9364d52014SChris Lattner 94085b687fSChris Lattner // When the root of a pattern is about to be replaced, it can trigger 95085b687fSChris Lattner // simplifications to its users - make sure to add them to the worklist 96085b687fSChris Lattner // before the root is changed. 97b499277fSRiver Riddle void notifyRootReplaced(Instruction *op) override { 98085b687fSChris Lattner for (auto *result : op->getResults()) 99085b687fSChris Lattner // TODO: Add a result->getUsers() iterator. 100b499277fSRiver Riddle for (auto &user : result->getUses()) 101b499277fSRiver Riddle addToWorklist(user.getOwner()); 102085b687fSChris Lattner } 103085b687fSChris Lattner 1044bd9f936SChris Lattner private: 105a8866258SRiver Riddle // Look over the provided operands for any defining instructions that should 106a8866258SRiver Riddle // be re-added to the worklist. This function should be called when an 107a8866258SRiver Riddle // operation is modified or removed, as it may trigger further 108a8866258SRiver Riddle // simplifications. 
109a8866258SRiver Riddle template <typename Operands> void addToWorklist(Operands &&operands) { 110a8866258SRiver Riddle for (Value *operand : operands) { 111a8866258SRiver Riddle // If the use count of this operand is now < 2, we re-add the defining 112a8866258SRiver Riddle // instruction to the worklist. 113a8866258SRiver Riddle // TODO(riverriddle) This is based on the fact that zero use instructions 114a8866258SRiver Riddle // may be deleted, and that single use values often have more 115a8866258SRiver Riddle // canonicalization opportunities. 116a8866258SRiver Riddle if (!operand->use_empty() && 117a8866258SRiver Riddle std::next(operand->use_begin()) != operand->use_end()) 118a8866258SRiver Riddle continue; 119a8866258SRiver Riddle if (auto *defInst = operand->getDefiningInst()) 120a8866258SRiver Riddle addToWorklist(defInst); 121a8866258SRiver Riddle } 122a8866258SRiver Riddle } 123a8866258SRiver Riddle 1244bd9f936SChris Lattner /// The low-level pattern matcher. 1255de726f4SRiver Riddle RewritePatternMatcher matcher; 1264bd9f936SChris Lattner 1274bd9f936SChris Lattner /// This builder is used to create new operations. 1284bd9f936SChris Lattner FuncBuilder builder; 1294bd9f936SChris Lattner 1304bd9f936SChris Lattner /// The worklist for this transformation keeps track of the operations that 1314bd9f936SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 1324bd9f936SChris Lattner /// efficiently remove operations from the worklist when they are erased from 1334bd9f936SChris Lattner /// the function, even if they aren't the root of a pattern. 134b499277fSRiver Riddle std::vector<Instruction *> worklist; 135b499277fSRiver Riddle DenseMap<Instruction *, unsigned> worklistMap; 1364bd9f936SChris Lattner 1374bd9f936SChris Lattner /// As part of canonicalization, we move constants to the top of the entry 1384bd9f936SChris Lattner /// block of the current function and de-duplicate them. 
This keeps track of 1394bd9f936SChris Lattner /// constants we have done this for. 140b499277fSRiver Riddle DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants; 14164d52014SChris Lattner }; 1424bd9f936SChris Lattner }; // end anonymous namespace 14364d52014SChris Lattner 1444bd9f936SChris Lattner /// Perform the rewrites. 1454bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() { 14664d52014SChris Lattner // These are scratch vectors used in the constant folding loop below. 147792d1c25SRiver Riddle SmallVector<Attribute, 8> operandConstants, resultConstants; 148934b6d12SChris Lattner SmallVector<Value *, 8> originalOperands, resultValues; 14964d52014SChris Lattner 15064d52014SChris Lattner while (!worklist.empty()) { 15164d52014SChris Lattner auto *op = popFromWorklist(); 15264d52014SChris Lattner 15364d52014SChris Lattner // Nulls get added to the worklist when operations are removed, ignore them. 15464d52014SChris Lattner if (op == nullptr) 15564d52014SChris Lattner continue; 15664d52014SChris Lattner 15764d52014SChris Lattner // If we have a constant op, unique it into the entry block. 15864d52014SChris Lattner if (auto constant = op->dyn_cast<ConstantOp>()) { 15964d52014SChris Lattner // If this constant is dead, remove it, being careful to keep 16064d52014SChris Lattner // uniquedConstants up to date. 
16196ebde9cSRiver Riddle if (constant.use_empty()) { 16264d52014SChris Lattner auto it = 16396ebde9cSRiver Riddle uniquedConstants.find({constant.getValue(), constant.getType()}); 16464d52014SChris Lattner if (it != uniquedConstants.end() && it->second == op) 16564d52014SChris Lattner uniquedConstants.erase(it); 16696ebde9cSRiver Riddle constant.erase(); 16764d52014SChris Lattner continue; 16864d52014SChris Lattner } 16964d52014SChris Lattner 17064d52014SChris Lattner // Check to see if we already have a constant with this type and value: 17196ebde9cSRiver Riddle auto &entry = uniquedConstants[std::make_pair(constant.getValue(), 17296ebde9cSRiver Riddle constant.getType())]; 17364d52014SChris Lattner if (entry) { 17464d52014SChris Lattner // If this constant is already our uniqued one, then leave it alone. 17564d52014SChris Lattner if (entry == op) 17664d52014SChris Lattner continue; 17764d52014SChris Lattner 17864d52014SChris Lattner // Otherwise replace this redundant constant with the uniqued one. We 17964d52014SChris Lattner // know this is safe because we move constants to the top of the 18064d52014SChris Lattner // function when they are uniqued, so we know they dominate all uses. 18196ebde9cSRiver Riddle constant.replaceAllUsesWith(entry->getResult(0)); 18296ebde9cSRiver Riddle constant.erase(); 18364d52014SChris Lattner continue; 18464d52014SChris Lattner } 18564d52014SChris Lattner 18664d52014SChris Lattner // If we have no entry, then we should unique this constant as the 18764d52014SChris Lattner // canonical version. To ensure safe dominance, move the operation to the 18864d52014SChris Lattner // top of the function. 
18964d52014SChris Lattner entry = op; 1904bd9f936SChris Lattner auto &entryBB = builder.getInsertionBlock()->getFunction()->front(); 191471c9764SChris Lattner op->moveBefore(&entryBB, entryBB.begin()); 19264d52014SChris Lattner continue; 19364d52014SChris Lattner } 19464d52014SChris Lattner 19564d52014SChris Lattner // If the operation has no side effects, and no users, then it is trivially 19664d52014SChris Lattner // dead - remove it. 19764d52014SChris Lattner if (op->hasNoSideEffect() && op->use_empty()) { 19864d52014SChris Lattner op->erase(); 19964d52014SChris Lattner continue; 20064d52014SChris Lattner } 20164d52014SChris Lattner 20264d52014SChris Lattner // Check to see if any operands to the instruction is constant and whether 20364d52014SChris Lattner // the operation knows how to constant fold itself. 204da0ebe06SRiver Riddle operandConstants.assign(op->getNumOperands(), Attribute()); 205da0ebe06SRiver Riddle for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 206da0ebe06SRiver Riddle matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 20764d52014SChris Lattner 208934b6d12SChris Lattner // If this is a commutative binary operation with a constant on the left 209934b6d12SChris Lattner // side move it to the right side. 210934b6d12SChris Lattner if (operandConstants.size() == 2 && operandConstants[0] && 211934b6d12SChris Lattner !operandConstants[1] && op->isCommutative()) { 212934b6d12SChris Lattner std::swap(op->getInstOperand(0), op->getInstOperand(1)); 213934b6d12SChris Lattner std::swap(operandConstants[0], operandConstants[1]); 214934b6d12SChris Lattner } 215934b6d12SChris Lattner 21664d52014SChris Lattner // If constant folding was successful, create the result constants, RAUW the 21764d52014SChris Lattner // operation and remove it. 
21864d52014SChris Lattner resultConstants.clear(); 2195e1f1d2cSRiver Riddle if (succeeded(op->constantFold(operandConstants, resultConstants))) { 2204bd9f936SChris Lattner builder.setInsertionPoint(op); 22164d52014SChris Lattner 222a8866258SRiver Riddle // Add the operands to the worklist for visitation. 223a8866258SRiver Riddle addToWorklist(op->getOperands()); 224a8866258SRiver Riddle 22564d52014SChris Lattner for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 22664d52014SChris Lattner auto *res = op->getResult(i); 22764d52014SChris Lattner if (res->use_empty()) // ignore dead uses. 22864d52014SChris Lattner continue; 22964d52014SChris Lattner 23064d52014SChris Lattner // If we already have a canonicalized version of this constant, just 23164d52014SChris Lattner // reuse it. Otherwise create a new one. 2323f190312SChris Lattner Value *cstValue; 23364d52014SChris Lattner auto it = uniquedConstants.find({resultConstants[i], res->getType()}); 23464d52014SChris Lattner if (it != uniquedConstants.end()) 23564d52014SChris Lattner cstValue = it->second->getResult(0); 23664d52014SChris Lattner else 23761ec6c09SLei Zhang cstValue = create<ConstantOp>(op->getLoc(), res->getType(), 23861ec6c09SLei Zhang resultConstants[i]); 239967d9341SChris Lattner 240967d9341SChris Lattner // Add all the users of the result to the worklist so we make sure to 241967d9341SChris Lattner // revisit them. 242967d9341SChris Lattner // 243967d9341SChris Lattner // TODO: Add a result->getUsers() iterator. 
244b499277fSRiver Riddle for (auto &operand : op->getResult(i)->getUses()) 245b499277fSRiver Riddle addToWorklist(operand.getOwner()); 246967d9341SChris Lattner 24764d52014SChris Lattner res->replaceAllUsesWith(cstValue); 24864d52014SChris Lattner } 24964d52014SChris Lattner 25064d52014SChris Lattner assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); 25164d52014SChris Lattner op->erase(); 25264d52014SChris Lattner continue; 25364d52014SChris Lattner } 25464d52014SChris Lattner 255934b6d12SChris Lattner // Otherwise see if we can use the generic folder API to simplify the 256934b6d12SChris Lattner // operation. 257934b6d12SChris Lattner originalOperands.assign(op->operand_begin(), op->operand_end()); 258934b6d12SChris Lattner resultValues.clear(); 2595e1f1d2cSRiver Riddle if (succeeded(op->fold(resultValues))) { 260934b6d12SChris Lattner // If the result was an in-place simplification (e.g. max(x,x,y) -> 261934b6d12SChris Lattner // max(x,y)) then add the original operands to the worklist so we can make 262934b6d12SChris Lattner // sure to revisit them. 263934b6d12SChris Lattner if (resultValues.empty()) { 264a8866258SRiver Riddle // Add the operands back to the worklist as there may be more 265a8866258SRiver Riddle // canonicalization opportunities now. 266a8866258SRiver Riddle addToWorklist(originalOperands); 267934b6d12SChris Lattner } else { 268934b6d12SChris Lattner // Otherwise, the operation is simplified away completely. 269934b6d12SChris Lattner assert(resultValues.size() == op->getNumResults()); 270934b6d12SChris Lattner 271a8866258SRiver Riddle // Notify that we are replacing this operation. 272a8866258SRiver Riddle notifyRootReplaced(op); 273a8866258SRiver Riddle 274a8866258SRiver Riddle // Replace the result values and erase the operation. 
275934b6d12SChris Lattner for (unsigned i = 0, e = resultValues.size(); i != e; ++i) { 276934b6d12SChris Lattner auto *res = op->getResult(i); 277a8866258SRiver Riddle if (!res->use_empty()) 278934b6d12SChris Lattner res->replaceAllUsesWith(resultValues[i]); 279934b6d12SChris Lattner } 280934b6d12SChris Lattner 281a8866258SRiver Riddle notifyOperationRemoved(op); 282934b6d12SChris Lattner op->erase(); 28399fee0b1SRiver Riddle } 284934b6d12SChris Lattner continue; 28564d52014SChris Lattner } 28664d52014SChris Lattner 28764d52014SChris Lattner // Make sure that any new operations are inserted at this point. 2884bd9f936SChris Lattner builder.setInsertionPoint(op); 2895de726f4SRiver Riddle 2905de726f4SRiver Riddle // Try to match one of the canonicalization patterns. The rewriter is 2915de726f4SRiver Riddle // automatically notified of any necessary changes, so there is nothing else 2925de726f4SRiver Riddle // to do here. 2935de726f4SRiver Riddle matcher.matchAndRewrite(op); 29464d52014SChris Lattner } 29564d52014SChris Lattner 29664d52014SChris Lattner uniquedConstants.clear(); 29764d52014SChris Lattner } 29864d52014SChris Lattner 29964d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit 30064d52014SChris Lattner /// patterns in a greedy work-list driven manner. 30164d52014SChris Lattner /// 302*46ade282SChris Lattner void mlir::applyPatternsGreedily(Function &fn, 3033f2530cdSChris Lattner OwningRewritePatternList &&patterns) { 3044bd9f936SChris Lattner GreedyPatternRewriteDriver driver(fn, std::move(patterns)); 3054bd9f936SChris Lattner driver.simplifyFunction(); 30664d52014SChris Lattner } 307