1*64d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 2*64d52014SChris Lattner // 3*64d52014SChris Lattner // Copyright 2019 The MLIR Authors. 4*64d52014SChris Lattner // 5*64d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License"); 6*64d52014SChris Lattner // you may not use this file except in compliance with the License. 7*64d52014SChris Lattner // You may obtain a copy of the License at 8*64d52014SChris Lattner // 9*64d52014SChris Lattner // http://www.apache.org/licenses/LICENSE-2.0 10*64d52014SChris Lattner // 11*64d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software 12*64d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS, 13*64d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14*64d52014SChris Lattner // See the License for the specific language governing permissions and 15*64d52014SChris Lattner // limitations under the License. 16*64d52014SChris Lattner // ============================================================================= 17*64d52014SChris Lattner // 18*64d52014SChris Lattner // This file implements mlir::applyPatternsGreedily. 19*64d52014SChris Lattner // 20*64d52014SChris Lattner //===----------------------------------------------------------------------===// 21*64d52014SChris Lattner 22*64d52014SChris Lattner #include "mlir/IR/Builders.h" 23*64d52014SChris Lattner #include "mlir/IR/BuiltinOps.h" 24*64d52014SChris Lattner #include "mlir/StandardOps/StandardOps.h" 25*64d52014SChris Lattner #include "mlir/Transforms/PatternMatch.h" 26*64d52014SChris Lattner #include "llvm/ADT/DenseMap.h" 27*64d52014SChris Lattner using namespace mlir; 28*64d52014SChris Lattner 29*64d52014SChris Lattner namespace { 30*64d52014SChris Lattner class WorklistRewriter; 31*64d52014SChris Lattner 32*64d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 33*64d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way. 34*64d52014SChris Lattner class GreedyPatternRewriteDriver { 35*64d52014SChris Lattner public: 36*64d52014SChris Lattner explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns) 37*64d52014SChris Lattner : matcher(std::move(patterns)) { 38*64d52014SChris Lattner worklist.reserve(64); 39*64d52014SChris Lattner } 40*64d52014SChris Lattner 41*64d52014SChris Lattner void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter); 42*64d52014SChris Lattner 43*64d52014SChris Lattner void addToWorklist(Operation *op) { 44*64d52014SChris Lattner worklistMap[op] = worklist.size(); 45*64d52014SChris Lattner worklist.push_back(op); 46*64d52014SChris Lattner } 47*64d52014SChris Lattner 48*64d52014SChris Lattner Operation *popFromWorklist() { 49*64d52014SChris Lattner auto *op = worklist.back(); 50*64d52014SChris Lattner worklist.pop_back(); 51*64d52014SChris Lattner 52*64d52014SChris Lattner // This operation is no longer in the worklist, keep worklistMap up to date. 53*64d52014SChris Lattner if (op) 54*64d52014SChris Lattner worklistMap.erase(op); 55*64d52014SChris Lattner return op; 56*64d52014SChris Lattner } 57*64d52014SChris Lattner 58*64d52014SChris Lattner /// If the specified operation is in the worklist, remove it. If not, this is 59*64d52014SChris Lattner /// a no-op. 60*64d52014SChris Lattner void removeFromWorklist(Operation *op) { 61*64d52014SChris Lattner auto it = worklistMap.find(op); 62*64d52014SChris Lattner if (it != worklistMap.end()) { 63*64d52014SChris Lattner assert(worklist[it->second] == op && "malformed worklist data structure"); 64*64d52014SChris Lattner worklist[it->second] = nullptr; 65*64d52014SChris Lattner } 66*64d52014SChris Lattner } 67*64d52014SChris Lattner 68*64d52014SChris Lattner private: 69*64d52014SChris Lattner /// The low-level pattern matcher. 70*64d52014SChris Lattner PatternMatcher matcher; 71*64d52014SChris Lattner 72*64d52014SChris Lattner /// The worklist for this transformation keeps track of the operations that 73*64d52014SChris Lattner /// need to be revisited, plus their index in the worklist. This allows us to 74*64d52014SChris Lattner /// efficiently remove operations from the worklist when they are removed even 75*64d52014SChris Lattner /// if they aren't the root of a pattern. 76*64d52014SChris Lattner std::vector<Operation *> worklist; 77*64d52014SChris Lattner DenseMap<Operation *, unsigned> worklistMap; 78*64d52014SChris Lattner 79*64d52014SChris Lattner /// As part of canonicalization, we move constants to the top of the entry 80*64d52014SChris Lattner /// block of the current function and de-duplicate them. This keeps track of 81*64d52014SChris Lattner /// constants we have done this for. 82*64d52014SChris Lattner DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants; 83*64d52014SChris Lattner }; 84*64d52014SChris Lattner }; // end anonymous namespace 85*64d52014SChris Lattner 86*64d52014SChris Lattner /// This is a listener object that updates our worklists and other data 87*64d52014SChris Lattner /// structures in response to operations being added and removed. 88*64d52014SChris Lattner namespace { 89*64d52014SChris Lattner class WorklistRewriter : public PatternRewriter { 90*64d52014SChris Lattner public: 91*64d52014SChris Lattner WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context) 92*64d52014SChris Lattner : PatternRewriter(context), driver(driver) {} 93*64d52014SChris Lattner 94*64d52014SChris Lattner virtual void setInsertionPoint(Operation *op) = 0; 95*64d52014SChris Lattner 96*64d52014SChris Lattner // If an operation is about to be removed, make sure it is not in our 97*64d52014SChris Lattner // worklist anymore because we'd get dangling references to it. 98*64d52014SChris Lattner void notifyOperationRemoved(Operation *op) override { 99*64d52014SChris Lattner driver.removeFromWorklist(op); 100*64d52014SChris Lattner } 101*64d52014SChris Lattner 102*64d52014SChris Lattner GreedyPatternRewriteDriver &driver; 103*64d52014SChris Lattner }; 104*64d52014SChris Lattner 105*64d52014SChris Lattner } // end anonymous namespace 106*64d52014SChris Lattner 107*64d52014SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, 108*64d52014SChris Lattner WorklistRewriter &rewriter) { 109*64d52014SChris Lattner // These are scratch vectors used in the constant folding loop below. 110*64d52014SChris Lattner SmallVector<Attribute *, 8> operandConstants, resultConstants; 111*64d52014SChris Lattner 112*64d52014SChris Lattner while (!worklist.empty()) { 113*64d52014SChris Lattner auto *op = popFromWorklist(); 114*64d52014SChris Lattner 115*64d52014SChris Lattner // Nulls get added to the worklist when operations are removed, ignore them. 116*64d52014SChris Lattner if (op == nullptr) 117*64d52014SChris Lattner continue; 118*64d52014SChris Lattner 119*64d52014SChris Lattner // If we have a constant op, unique it into the entry block. 120*64d52014SChris Lattner if (auto constant = op->dyn_cast<ConstantOp>()) { 121*64d52014SChris Lattner // If this constant is dead, remove it, being careful to keep 122*64d52014SChris Lattner // uniquedConstants up to date. 123*64d52014SChris Lattner if (constant->use_empty()) { 124*64d52014SChris Lattner auto it = 125*64d52014SChris Lattner uniquedConstants.find({constant->getValue(), constant->getType()}); 126*64d52014SChris Lattner if (it != uniquedConstants.end() && it->second == op) 127*64d52014SChris Lattner uniquedConstants.erase(it); 128*64d52014SChris Lattner constant->erase(); 129*64d52014SChris Lattner continue; 130*64d52014SChris Lattner } 131*64d52014SChris Lattner 132*64d52014SChris Lattner // Check to see if we already have a constant with this type and value: 133*64d52014SChris Lattner auto &entry = uniquedConstants[std::make_pair(constant->getValue(), 134*64d52014SChris Lattner constant->getType())]; 135*64d52014SChris Lattner if (entry) { 136*64d52014SChris Lattner // If this constant is already our uniqued one, then leave it alone. 137*64d52014SChris Lattner if (entry == op) 138*64d52014SChris Lattner continue; 139*64d52014SChris Lattner 140*64d52014SChris Lattner // Otherwise replace this redundant constant with the uniqued one. We 141*64d52014SChris Lattner // know this is safe because we move constants to the top of the 142*64d52014SChris Lattner // function when they are uniqued, so we know they dominate all uses. 143*64d52014SChris Lattner constant->replaceAllUsesWith(entry->getResult(0)); 144*64d52014SChris Lattner constant->erase(); 145*64d52014SChris Lattner continue; 146*64d52014SChris Lattner } 147*64d52014SChris Lattner 148*64d52014SChris Lattner // If we have no entry, then we should unique this constant as the 149*64d52014SChris Lattner // canonical version. To ensure safe dominance, move the operation to the 150*64d52014SChris Lattner // top of the function. 151*64d52014SChris Lattner entry = op; 152*64d52014SChris Lattner 153*64d52014SChris Lattner // TODO: If we make terminators into Operations then we could turn this 154*64d52014SChris Lattner // into a nice Operation::moveBefore(Operation*) method. We just need the 155*64d52014SChris Lattner // guarantee that a block is non-empty. 156*64d52014SChris Lattner if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) { 157*64d52014SChris Lattner auto &entryBB = cfgFunc->front(); 158*64d52014SChris Lattner cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin()); 159*64d52014SChris Lattner } else { 160*64d52014SChris Lattner auto *mlFunc = cast<MLFunction>(currentFunction); 161*64d52014SChris Lattner cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin()); 162*64d52014SChris Lattner } 163*64d52014SChris Lattner 164*64d52014SChris Lattner continue; 165*64d52014SChris Lattner } 166*64d52014SChris Lattner 167*64d52014SChris Lattner // If the operation has no side effects, and no users, then it is trivially 168*64d52014SChris Lattner // dead - remove it. 169*64d52014SChris Lattner if (op->hasNoSideEffect() && op->use_empty()) { 170*64d52014SChris Lattner op->erase(); 171*64d52014SChris Lattner continue; 172*64d52014SChris Lattner } 173*64d52014SChris Lattner 174*64d52014SChris Lattner // Check to see if any operands to the instruction is constant and whether 175*64d52014SChris Lattner // the operation knows how to constant fold itself. 176*64d52014SChris Lattner operandConstants.clear(); 177*64d52014SChris Lattner for (auto *operand : op->getOperands()) { 178*64d52014SChris Lattner Attribute *operandCst = nullptr; 179*64d52014SChris Lattner if (auto *operandOp = operand->getDefiningOperation()) { 180*64d52014SChris Lattner if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) 181*64d52014SChris Lattner operandCst = operandConstantOp->getValue(); 182*64d52014SChris Lattner } 183*64d52014SChris Lattner operandConstants.push_back(operandCst); 184*64d52014SChris Lattner } 185*64d52014SChris Lattner 186*64d52014SChris Lattner // If constant folding was successful, create the result constants, RAUW the 187*64d52014SChris Lattner // operation and remove it. 188*64d52014SChris Lattner resultConstants.clear(); 189*64d52014SChris Lattner if (!op->constantFold(operandConstants, resultConstants)) { 190*64d52014SChris Lattner rewriter.setInsertionPoint(op); 191*64d52014SChris Lattner 192*64d52014SChris Lattner for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 193*64d52014SChris Lattner auto *res = op->getResult(i); 194*64d52014SChris Lattner if (res->use_empty()) // ignore dead uses. 195*64d52014SChris Lattner continue; 196*64d52014SChris Lattner 197*64d52014SChris Lattner // If we already have a canonicalized version of this constant, just 198*64d52014SChris Lattner // reuse it. Otherwise create a new one. 199*64d52014SChris Lattner SSAValue *cstValue; 200*64d52014SChris Lattner auto it = uniquedConstants.find({resultConstants[i], res->getType()}); 201*64d52014SChris Lattner if (it != uniquedConstants.end()) 202*64d52014SChris Lattner cstValue = it->second->getResult(0); 203*64d52014SChris Lattner else 204*64d52014SChris Lattner cstValue = rewriter.create<ConstantOp>( 205*64d52014SChris Lattner op->getLoc(), resultConstants[i], res->getType()); 206*64d52014SChris Lattner res->replaceAllUsesWith(cstValue); 207*64d52014SChris Lattner } 208*64d52014SChris Lattner 209*64d52014SChris Lattner assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); 210*64d52014SChris Lattner op->erase(); 211*64d52014SChris Lattner continue; 212*64d52014SChris Lattner } 213*64d52014SChris Lattner 214*64d52014SChris Lattner // If this is an associative binary operation with a constant on the LHS, 215*64d52014SChris Lattner // move it to the right side. 216*64d52014SChris Lattner if (operandConstants.size() == 2 && operandConstants[0] && 217*64d52014SChris Lattner !operandConstants[1]) { 218*64d52014SChris Lattner auto *newLHS = op->getOperand(1); 219*64d52014SChris Lattner op->setOperand(1, op->getOperand(0)); 220*64d52014SChris Lattner op->setOperand(0, newLHS); 221*64d52014SChris Lattner } 222*64d52014SChris Lattner 223*64d52014SChris Lattner // Check to see if we have any patterns that match this node. 224*64d52014SChris Lattner auto match = matcher.findMatch(op); 225*64d52014SChris Lattner if (!match.first) 226*64d52014SChris Lattner continue; 227*64d52014SChris Lattner 228*64d52014SChris Lattner // Make sure that any new operations are inserted at this point. 229*64d52014SChris Lattner rewriter.setInsertionPoint(op); 230*64d52014SChris Lattner match.first->rewrite(op, std::move(match.second), rewriter); 231*64d52014SChris Lattner } 232*64d52014SChris Lattner 233*64d52014SChris Lattner uniquedConstants.clear(); 234*64d52014SChris Lattner } 235*64d52014SChris Lattner 236*64d52014SChris Lattner static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) { 237*64d52014SChris Lattner class MLFuncRewriter : public WorklistRewriter { 238*64d52014SChris Lattner public: 239*64d52014SChris Lattner MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder) 240*64d52014SChris Lattner : WorklistRewriter(driver, builder.getContext()), builder(builder) {} 241*64d52014SChris Lattner 242*64d52014SChris Lattner // Implement the hook for creating operations, and make sure that newly 243*64d52014SChris Lattner // created ops are added to the worklist for processing. 244*64d52014SChris Lattner Operation *createOperation(const OperationState &state) override { 245*64d52014SChris Lattner auto *result = builder.createOperation(state); 246*64d52014SChris Lattner driver.addToWorklist(result); 247*64d52014SChris Lattner return result; 248*64d52014SChris Lattner } 249*64d52014SChris Lattner 250*64d52014SChris Lattner // When the root of a pattern is about to be replaced, it can trigger 251*64d52014SChris Lattner // simplifications to its users - make sure to add them to the worklist 252*64d52014SChris Lattner // before the root is changed. 253*64d52014SChris Lattner void notifyRootReplaced(Operation *op) override { 254*64d52014SChris Lattner auto *opStmt = cast<OperationStmt>(op); 255*64d52014SChris Lattner for (auto *result : opStmt->getResults()) 256*64d52014SChris Lattner // TODO: Add a result->getUsers() iterator. 257*64d52014SChris Lattner for (auto &user : result->getUses()) { 258*64d52014SChris Lattner if (auto *op = dyn_cast<OperationStmt>(user.getOwner())) 259*64d52014SChris Lattner driver.addToWorklist(op); 260*64d52014SChris Lattner } 261*64d52014SChris Lattner 262*64d52014SChris Lattner // TODO: Walk the operand list dropping them as we go. If any of them 263*64d52014SChris Lattner // drop to zero uses, then add them to the worklist to allow them to be 264*64d52014SChris Lattner // deleted as dead. 265*64d52014SChris Lattner } 266*64d52014SChris Lattner 267*64d52014SChris Lattner void setInsertionPoint(Operation *op) override { 268*64d52014SChris Lattner // Any new operations should be added before this statement. 269*64d52014SChris Lattner builder.setInsertionPoint(cast<OperationStmt>(op)); 270*64d52014SChris Lattner } 271*64d52014SChris Lattner 272*64d52014SChris Lattner private: 273*64d52014SChris Lattner MLFuncBuilder &builder; 274*64d52014SChris Lattner }; 275*64d52014SChris Lattner 276*64d52014SChris Lattner GreedyPatternRewriteDriver driver(std::move(patterns)); 277*64d52014SChris Lattner fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); }); 278*64d52014SChris Lattner 279*64d52014SChris Lattner MLFuncBuilder mlBuilder(fn); 280*64d52014SChris Lattner MLFuncRewriter rewriter(driver, mlBuilder); 281*64d52014SChris Lattner driver.simplifyFunction(fn, rewriter); 282*64d52014SChris Lattner } 283*64d52014SChris Lattner 284*64d52014SChris Lattner static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) { 285*64d52014SChris Lattner class CFGFuncRewriter : public WorklistRewriter { 286*64d52014SChris Lattner public: 287*64d52014SChris Lattner CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder) 288*64d52014SChris Lattner : WorklistRewriter(driver, builder.getContext()), builder(builder) {} 289*64d52014SChris Lattner 290*64d52014SChris Lattner // Implement the hook for creating operations, and make sure that newly 291*64d52014SChris Lattner // created ops are added to the worklist for processing. 292*64d52014SChris Lattner Operation *createOperation(const OperationState &state) override { 293*64d52014SChris Lattner auto *result = builder.createOperation(state); 294*64d52014SChris Lattner driver.addToWorklist(result); 295*64d52014SChris Lattner return result; 296*64d52014SChris Lattner } 297*64d52014SChris Lattner 298*64d52014SChris Lattner // When the root of a pattern is about to be replaced, it can trigger 299*64d52014SChris Lattner // simplifications to its users - make sure to add them to the worklist 300*64d52014SChris Lattner // before the root is changed. 301*64d52014SChris Lattner void notifyRootReplaced(Operation *op) override { 302*64d52014SChris Lattner auto *opStmt = cast<OperationInst>(op); 303*64d52014SChris Lattner for (auto *result : opStmt->getResults()) 304*64d52014SChris Lattner // TODO: Add a result->getUsers() iterator. 305*64d52014SChris Lattner for (auto &user : result->getUses()) { 306*64d52014SChris Lattner if (auto *op = dyn_cast<OperationInst>(user.getOwner())) 307*64d52014SChris Lattner driver.addToWorklist(op); 308*64d52014SChris Lattner } 309*64d52014SChris Lattner 310*64d52014SChris Lattner // TODO: Walk the operand list dropping them as we go. If any of them 311*64d52014SChris Lattner // drop to zero uses, then add them to the worklist to allow them to be 312*64d52014SChris Lattner // deleted as dead. 313*64d52014SChris Lattner } 314*64d52014SChris Lattner 315*64d52014SChris Lattner void setInsertionPoint(Operation *op) override { 316*64d52014SChris Lattner // Any new operations should be added before this instruction. 317*64d52014SChris Lattner builder.setInsertionPoint(cast<OperationInst>(op)); 318*64d52014SChris Lattner } 319*64d52014SChris Lattner 320*64d52014SChris Lattner private: 321*64d52014SChris Lattner CFGFuncBuilder &builder; 322*64d52014SChris Lattner }; 323*64d52014SChris Lattner 324*64d52014SChris Lattner GreedyPatternRewriteDriver driver(std::move(patterns)); 325*64d52014SChris Lattner for (auto &bb : *fn) 326*64d52014SChris Lattner for (auto &op : bb) 327*64d52014SChris Lattner driver.addToWorklist(&op); 328*64d52014SChris Lattner 329*64d52014SChris Lattner CFGFuncBuilder cfgBuilder(fn); 330*64d52014SChris Lattner CFGFuncRewriter rewriter(driver, cfgBuilder); 331*64d52014SChris Lattner driver.simplifyFunction(fn, rewriter); 332*64d52014SChris Lattner } 333*64d52014SChris Lattner 334*64d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit 335*64d52014SChris Lattner /// patterns in a greedy work-list driven manner. 336*64d52014SChris Lattner /// 337*64d52014SChris Lattner void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) { 338*64d52014SChris Lattner if (auto *cfg = dyn_cast<CFGFunction>(fn)) { 339*64d52014SChris Lattner processCFGFunction(cfg, std::move(patterns)); 340*64d52014SChris Lattner } else { 341*64d52014SChris Lattner processMLFunction(cast<MLFunction>(fn), std::move(patterns)); 342*64d52014SChris Lattner } 343*64d52014SChris Lattner } 344