164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
364d52014SChris Lattner // Copyright 2019 The MLIR Authors.
464d52014SChris Lattner //
564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License");
664d52014SChris Lattner // you may not use this file except in compliance with the License.
764d52014SChris Lattner // You may obtain a copy of the License at
864d52014SChris Lattner //
964d52014SChris Lattner //   http://www.apache.org/licenses/LICENSE-2.0
1064d52014SChris Lattner //
1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software
1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS,
1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1464d52014SChris Lattner // See the License for the specific language governing permissions and
1564d52014SChris Lattner // limitations under the License.
1664d52014SChris Lattner // =============================================================================
1764d52014SChris Lattner //
1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily.
1964d52014SChris Lattner //
2064d52014SChris Lattner //===----------------------------------------------------------------------===//
2164d52014SChris Lattner 
2264d52014SChris Lattner #include "mlir/IR/Builders.h"
2364d52014SChris Lattner #include "mlir/IR/BuiltinOps.h"
24*7de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
2564d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
2664d52014SChris Lattner using namespace mlir;
2764d52014SChris Lattner 
2864d52014SChris Lattner namespace {
2964d52014SChris Lattner class WorklistRewriter;
3064d52014SChris Lattner 
3164d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3264d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
3364d52014SChris Lattner class GreedyPatternRewriteDriver {
3464d52014SChris Lattner public:
3564d52014SChris Lattner   explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
3664d52014SChris Lattner       : matcher(std::move(patterns)) {
3764d52014SChris Lattner     worklist.reserve(64);
3864d52014SChris Lattner   }
3964d52014SChris Lattner 
4064d52014SChris Lattner   void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
4164d52014SChris Lattner 
4264d52014SChris Lattner   void addToWorklist(Operation *op) {
4364d52014SChris Lattner     worklistMap[op] = worklist.size();
4464d52014SChris Lattner     worklist.push_back(op);
4564d52014SChris Lattner   }
4664d52014SChris Lattner 
4764d52014SChris Lattner   Operation *popFromWorklist() {
4864d52014SChris Lattner     auto *op = worklist.back();
4964d52014SChris Lattner     worklist.pop_back();
5064d52014SChris Lattner 
5164d52014SChris Lattner     // This operation is no longer in the worklist, keep worklistMap up to date.
5264d52014SChris Lattner     if (op)
5364d52014SChris Lattner       worklistMap.erase(op);
5464d52014SChris Lattner     return op;
5564d52014SChris Lattner   }
5664d52014SChris Lattner 
5764d52014SChris Lattner   /// If the specified operation is in the worklist, remove it.  If not, this is
5864d52014SChris Lattner   /// a no-op.
5964d52014SChris Lattner   void removeFromWorklist(Operation *op) {
6064d52014SChris Lattner     auto it = worklistMap.find(op);
6164d52014SChris Lattner     if (it != worklistMap.end()) {
6264d52014SChris Lattner       assert(worklist[it->second] == op && "malformed worklist data structure");
6364d52014SChris Lattner       worklist[it->second] = nullptr;
6464d52014SChris Lattner     }
6564d52014SChris Lattner   }
6664d52014SChris Lattner 
6764d52014SChris Lattner private:
6864d52014SChris Lattner   /// The low-level pattern matcher.
6964d52014SChris Lattner   PatternMatcher matcher;
7064d52014SChris Lattner 
7164d52014SChris Lattner   /// The worklist for this transformation keeps track of the operations that
7264d52014SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
7364d52014SChris Lattner   /// efficiently remove operations from the worklist when they are removed even
7464d52014SChris Lattner   /// if they aren't the root of a pattern.
7564d52014SChris Lattner   std::vector<Operation *> worklist;
7664d52014SChris Lattner   DenseMap<Operation *, unsigned> worklistMap;
7764d52014SChris Lattner 
7864d52014SChris Lattner   /// As part of canonicalization, we move constants to the top of the entry
7964d52014SChris Lattner   /// block of the current function and de-duplicate them.  This keeps track of
8064d52014SChris Lattner   /// constants we have done this for.
81792d1c25SRiver Riddle   DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
8264d52014SChris Lattner };
8364d52014SChris Lattner }; // end anonymous namespace
8464d52014SChris Lattner 
8564d52014SChris Lattner /// This is a listener object that updates our worklists and other data
8664d52014SChris Lattner /// structures in response to operations being added and removed.
8764d52014SChris Lattner namespace {
8864d52014SChris Lattner class WorklistRewriter : public PatternRewriter {
8964d52014SChris Lattner public:
9064d52014SChris Lattner   WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
9164d52014SChris Lattner       : PatternRewriter(context), driver(driver) {}
9264d52014SChris Lattner 
9364d52014SChris Lattner   virtual void setInsertionPoint(Operation *op) = 0;
9464d52014SChris Lattner 
9564d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
9664d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
9764d52014SChris Lattner   void notifyOperationRemoved(Operation *op) override {
9864d52014SChris Lattner     driver.removeFromWorklist(op);
9964d52014SChris Lattner   }
10064d52014SChris Lattner 
10164d52014SChris Lattner   GreedyPatternRewriteDriver &driver;
10264d52014SChris Lattner };
10364d52014SChris Lattner 
10464d52014SChris Lattner } // end anonymous namespace
10564d52014SChris Lattner 
10664d52014SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
10764d52014SChris Lattner                                                   WorklistRewriter &rewriter) {
10864d52014SChris Lattner   // These are scratch vectors used in the constant folding loop below.
109792d1c25SRiver Riddle   SmallVector<Attribute, 8> operandConstants, resultConstants;
11064d52014SChris Lattner 
11164d52014SChris Lattner   while (!worklist.empty()) {
11264d52014SChris Lattner     auto *op = popFromWorklist();
11364d52014SChris Lattner 
11464d52014SChris Lattner     // Nulls get added to the worklist when operations are removed, ignore them.
11564d52014SChris Lattner     if (op == nullptr)
11664d52014SChris Lattner       continue;
11764d52014SChris Lattner 
11864d52014SChris Lattner     // If we have a constant op, unique it into the entry block.
11964d52014SChris Lattner     if (auto constant = op->dyn_cast<ConstantOp>()) {
12064d52014SChris Lattner       // If this constant is dead, remove it, being careful to keep
12164d52014SChris Lattner       // uniquedConstants up to date.
12264d52014SChris Lattner       if (constant->use_empty()) {
12364d52014SChris Lattner         auto it =
12464d52014SChris Lattner             uniquedConstants.find({constant->getValue(), constant->getType()});
12564d52014SChris Lattner         if (it != uniquedConstants.end() && it->second == op)
12664d52014SChris Lattner           uniquedConstants.erase(it);
12764d52014SChris Lattner         constant->erase();
12864d52014SChris Lattner         continue;
12964d52014SChris Lattner       }
13064d52014SChris Lattner 
13164d52014SChris Lattner       // Check to see if we already have a constant with this type and value:
13264d52014SChris Lattner       auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
13364d52014SChris Lattner                                                     constant->getType())];
13464d52014SChris Lattner       if (entry) {
13564d52014SChris Lattner         // If this constant is already our uniqued one, then leave it alone.
13664d52014SChris Lattner         if (entry == op)
13764d52014SChris Lattner           continue;
13864d52014SChris Lattner 
13964d52014SChris Lattner         // Otherwise replace this redundant constant with the uniqued one.  We
14064d52014SChris Lattner         // know this is safe because we move constants to the top of the
14164d52014SChris Lattner         // function when they are uniqued, so we know they dominate all uses.
14264d52014SChris Lattner         constant->replaceAllUsesWith(entry->getResult(0));
14364d52014SChris Lattner         constant->erase();
14464d52014SChris Lattner         continue;
14564d52014SChris Lattner       }
14664d52014SChris Lattner 
14764d52014SChris Lattner       // If we have no entry, then we should unique this constant as the
14864d52014SChris Lattner       // canonical version.  To ensure safe dominance, move the operation to the
14964d52014SChris Lattner       // top of the function.
15064d52014SChris Lattner       entry = op;
15164d52014SChris Lattner 
15264d52014SChris Lattner       // TODO: If we make terminators into Operations then we could turn this
15364d52014SChris Lattner       // into a nice Operation::moveBefore(Operation*) method.  We just need the
15464d52014SChris Lattner       // guarantee that a block is non-empty.
15564d52014SChris Lattner       if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
15664d52014SChris Lattner         auto &entryBB = cfgFunc->front();
15764d52014SChris Lattner         cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
15864d52014SChris Lattner       } else {
15964d52014SChris Lattner         auto *mlFunc = cast<MLFunction>(currentFunction);
16064d52014SChris Lattner         cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
16164d52014SChris Lattner       }
16264d52014SChris Lattner 
16364d52014SChris Lattner       continue;
16464d52014SChris Lattner     }
16564d52014SChris Lattner 
16664d52014SChris Lattner     // If the operation has no side effects, and no users, then it is trivially
16764d52014SChris Lattner     // dead - remove it.
16864d52014SChris Lattner     if (op->hasNoSideEffect() && op->use_empty()) {
16964d52014SChris Lattner       op->erase();
17064d52014SChris Lattner       continue;
17164d52014SChris Lattner     }
17264d52014SChris Lattner 
17364d52014SChris Lattner     // Check to see if any operands to the instruction is constant and whether
17464d52014SChris Lattner     // the operation knows how to constant fold itself.
17564d52014SChris Lattner     operandConstants.clear();
17664d52014SChris Lattner     for (auto *operand : op->getOperands()) {
177792d1c25SRiver Riddle       Attribute operandCst;
17864d52014SChris Lattner       if (auto *operandOp = operand->getDefiningOperation()) {
17964d52014SChris Lattner         if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
18064d52014SChris Lattner           operandCst = operandConstantOp->getValue();
18164d52014SChris Lattner       }
18264d52014SChris Lattner       operandConstants.push_back(operandCst);
18364d52014SChris Lattner     }
18464d52014SChris Lattner 
18564d52014SChris Lattner     // If constant folding was successful, create the result constants, RAUW the
18664d52014SChris Lattner     // operation and remove it.
18764d52014SChris Lattner     resultConstants.clear();
18864d52014SChris Lattner     if (!op->constantFold(operandConstants, resultConstants)) {
18964d52014SChris Lattner       rewriter.setInsertionPoint(op);
19064d52014SChris Lattner 
19164d52014SChris Lattner       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
19264d52014SChris Lattner         auto *res = op->getResult(i);
19364d52014SChris Lattner         if (res->use_empty()) // ignore dead uses.
19464d52014SChris Lattner           continue;
19564d52014SChris Lattner 
19664d52014SChris Lattner         // If we already have a canonicalized version of this constant, just
19764d52014SChris Lattner         // reuse it.  Otherwise create a new one.
19864d52014SChris Lattner         SSAValue *cstValue;
19964d52014SChris Lattner         auto it = uniquedConstants.find({resultConstants[i], res->getType()});
20064d52014SChris Lattner         if (it != uniquedConstants.end())
20164d52014SChris Lattner           cstValue = it->second->getResult(0);
20264d52014SChris Lattner         else
20364d52014SChris Lattner           cstValue = rewriter.create<ConstantOp>(
20464d52014SChris Lattner               op->getLoc(), resultConstants[i], res->getType());
20564d52014SChris Lattner         res->replaceAllUsesWith(cstValue);
20664d52014SChris Lattner       }
20764d52014SChris Lattner 
20864d52014SChris Lattner       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
20964d52014SChris Lattner       op->erase();
21064d52014SChris Lattner       continue;
21164d52014SChris Lattner     }
21264d52014SChris Lattner 
21364d52014SChris Lattner     // If this is an associative binary operation with a constant on the LHS,
21464d52014SChris Lattner     // move it to the right side.
21564d52014SChris Lattner     if (operandConstants.size() == 2 && operandConstants[0] &&
21664d52014SChris Lattner         !operandConstants[1]) {
21764d52014SChris Lattner       auto *newLHS = op->getOperand(1);
21864d52014SChris Lattner       op->setOperand(1, op->getOperand(0));
21964d52014SChris Lattner       op->setOperand(0, newLHS);
22064d52014SChris Lattner     }
22164d52014SChris Lattner 
22264d52014SChris Lattner     // Check to see if we have any patterns that match this node.
22364d52014SChris Lattner     auto match = matcher.findMatch(op);
22464d52014SChris Lattner     if (!match.first)
22564d52014SChris Lattner       continue;
22664d52014SChris Lattner 
22764d52014SChris Lattner     // Make sure that any new operations are inserted at this point.
22864d52014SChris Lattner     rewriter.setInsertionPoint(op);
22964d52014SChris Lattner     match.first->rewrite(op, std::move(match.second), rewriter);
23064d52014SChris Lattner   }
23164d52014SChris Lattner 
23264d52014SChris Lattner   uniquedConstants.clear();
23364d52014SChris Lattner }
23464d52014SChris Lattner 
23564d52014SChris Lattner static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
23664d52014SChris Lattner   class MLFuncRewriter : public WorklistRewriter {
23764d52014SChris Lattner   public:
23864d52014SChris Lattner     MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
23964d52014SChris Lattner         : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
24064d52014SChris Lattner 
24164d52014SChris Lattner     // Implement the hook for creating operations, and make sure that newly
24264d52014SChris Lattner     // created ops are added to the worklist for processing.
24364d52014SChris Lattner     Operation *createOperation(const OperationState &state) override {
24464d52014SChris Lattner       auto *result = builder.createOperation(state);
24564d52014SChris Lattner       driver.addToWorklist(result);
24664d52014SChris Lattner       return result;
24764d52014SChris Lattner     }
24864d52014SChris Lattner 
24964d52014SChris Lattner     // When the root of a pattern is about to be replaced, it can trigger
25064d52014SChris Lattner     // simplifications to its users - make sure to add them to the worklist
25164d52014SChris Lattner     // before the root is changed.
25264d52014SChris Lattner     void notifyRootReplaced(Operation *op) override {
25364d52014SChris Lattner       auto *opStmt = cast<OperationStmt>(op);
25464d52014SChris Lattner       for (auto *result : opStmt->getResults())
25564d52014SChris Lattner         // TODO: Add a result->getUsers() iterator.
25664d52014SChris Lattner         for (auto &user : result->getUses()) {
25764d52014SChris Lattner           if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
25864d52014SChris Lattner             driver.addToWorklist(op);
25964d52014SChris Lattner         }
26064d52014SChris Lattner 
26164d52014SChris Lattner       // TODO: Walk the operand list dropping them as we go.  If any of them
26264d52014SChris Lattner       // drop to zero uses, then add them to the worklist to allow them to be
26364d52014SChris Lattner       // deleted as dead.
26464d52014SChris Lattner     }
26564d52014SChris Lattner 
26664d52014SChris Lattner     void setInsertionPoint(Operation *op) override {
26764d52014SChris Lattner       // Any new operations should be added before this statement.
26864d52014SChris Lattner       builder.setInsertionPoint(cast<OperationStmt>(op));
26964d52014SChris Lattner     }
27064d52014SChris Lattner 
27164d52014SChris Lattner   private:
27264d52014SChris Lattner     MLFuncBuilder &builder;
27364d52014SChris Lattner   };
27464d52014SChris Lattner 
27564d52014SChris Lattner   GreedyPatternRewriteDriver driver(std::move(patterns));
27664d52014SChris Lattner   fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
27764d52014SChris Lattner 
27864d52014SChris Lattner   MLFuncBuilder mlBuilder(fn);
27964d52014SChris Lattner   MLFuncRewriter rewriter(driver, mlBuilder);
28064d52014SChris Lattner   driver.simplifyFunction(fn, rewriter);
28164d52014SChris Lattner }
28264d52014SChris Lattner 
28364d52014SChris Lattner static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
28464d52014SChris Lattner   class CFGFuncRewriter : public WorklistRewriter {
28564d52014SChris Lattner   public:
28664d52014SChris Lattner     CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
28764d52014SChris Lattner         : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
28864d52014SChris Lattner 
28964d52014SChris Lattner     // Implement the hook for creating operations, and make sure that newly
29064d52014SChris Lattner     // created ops are added to the worklist for processing.
29164d52014SChris Lattner     Operation *createOperation(const OperationState &state) override {
29264d52014SChris Lattner       auto *result = builder.createOperation(state);
29364d52014SChris Lattner       driver.addToWorklist(result);
29464d52014SChris Lattner       return result;
29564d52014SChris Lattner     }
29664d52014SChris Lattner 
29764d52014SChris Lattner     // When the root of a pattern is about to be replaced, it can trigger
29864d52014SChris Lattner     // simplifications to its users - make sure to add them to the worklist
29964d52014SChris Lattner     // before the root is changed.
30064d52014SChris Lattner     void notifyRootReplaced(Operation *op) override {
30164d52014SChris Lattner       auto *opStmt = cast<OperationInst>(op);
30264d52014SChris Lattner       for (auto *result : opStmt->getResults())
30364d52014SChris Lattner         // TODO: Add a result->getUsers() iterator.
30464d52014SChris Lattner         for (auto &user : result->getUses()) {
30564d52014SChris Lattner           if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
30664d52014SChris Lattner             driver.addToWorklist(op);
30764d52014SChris Lattner         }
30864d52014SChris Lattner 
30964d52014SChris Lattner       // TODO: Walk the operand list dropping them as we go.  If any of them
31064d52014SChris Lattner       // drop to zero uses, then add them to the worklist to allow them to be
31164d52014SChris Lattner       // deleted as dead.
31264d52014SChris Lattner     }
31364d52014SChris Lattner 
31464d52014SChris Lattner     void setInsertionPoint(Operation *op) override {
31564d52014SChris Lattner       // Any new operations should be added before this instruction.
31664d52014SChris Lattner       builder.setInsertionPoint(cast<OperationInst>(op));
31764d52014SChris Lattner     }
31864d52014SChris Lattner 
31964d52014SChris Lattner   private:
32064d52014SChris Lattner     CFGFuncBuilder &builder;
32164d52014SChris Lattner   };
32264d52014SChris Lattner 
32364d52014SChris Lattner   GreedyPatternRewriteDriver driver(std::move(patterns));
32464d52014SChris Lattner   for (auto &bb : *fn)
32564d52014SChris Lattner     for (auto &op : bb)
32664d52014SChris Lattner       driver.addToWorklist(&op);
32764d52014SChris Lattner 
32864d52014SChris Lattner   CFGFuncBuilder cfgBuilder(fn);
32964d52014SChris Lattner   CFGFuncRewriter rewriter(driver, cfgBuilder);
33064d52014SChris Lattner   driver.simplifyFunction(fn, rewriter);
33164d52014SChris Lattner }
33264d52014SChris Lattner 
33364d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit
33464d52014SChris Lattner /// patterns in a greedy work-list driven manner.
33564d52014SChris Lattner ///
33664d52014SChris Lattner void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
33764d52014SChris Lattner   if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
33864d52014SChris Lattner     processCFGFunction(cfg, std::move(patterns));
33964d52014SChris Lattner   } else {
34064d52014SChris Lattner     processMLFunction(cast<MLFunction>(fn), std::move(patterns));
34164d52014SChris Lattner   }
34264d52014SChris Lattner }
343