164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
364d52014SChris Lattner // Copyright 2019 The MLIR Authors.
464d52014SChris Lattner //
564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License");
664d52014SChris Lattner // you may not use this file except in compliance with the License.
764d52014SChris Lattner // You may obtain a copy of the License at
864d52014SChris Lattner //
964d52014SChris Lattner //   http://www.apache.org/licenses/LICENSE-2.0
1064d52014SChris Lattner //
1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software
1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS,
1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1464d52014SChris Lattner // See the License for the specific language governing permissions and
1564d52014SChris Lattner // limitations under the License.
1664d52014SChris Lattner // =============================================================================
1764d52014SChris Lattner //
1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily.
1964d52014SChris Lattner //
2064d52014SChris Lattner //===----------------------------------------------------------------------===//
2164d52014SChris Lattner 
2264d52014SChris Lattner #include "mlir/IR/Builders.h"
2364d52014SChris Lattner #include "mlir/IR/BuiltinOps.h"
247de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
2564d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
2664d52014SChris Lattner using namespace mlir;
2764d52014SChris Lattner 
2864d52014SChris Lattner namespace {
2964d52014SChris Lattner 
3064d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3164d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
324bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3364d52014SChris Lattner public:
344bd9f936SChris Lattner   explicit GreedyPatternRewriteDriver(Function *fn,
354bd9f936SChris Lattner                                       OwningRewritePatternList &&patterns)
364bd9f936SChris Lattner       : PatternRewriter(fn->getContext()), matcher(std::move(patterns)),
374bd9f936SChris Lattner         builder(fn) {
3864d52014SChris Lattner     worklist.reserve(64);
394bd9f936SChris Lattner 
404bd9f936SChris Lattner     // Add all operations to the worklist.
414bd9f936SChris Lattner     fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); });
4264d52014SChris Lattner   }
4364d52014SChris Lattner 
444bd9f936SChris Lattner   /// Perform the rewrites.
454bd9f936SChris Lattner   void simplifyFunction();
4664d52014SChris Lattner 
475187cfcfSChris Lattner   void addToWorklist(OperationInst *op) {
485c4f1fddSRiver Riddle     // Check to see if the worklist already contains this op.
495c4f1fddSRiver Riddle     if (worklistMap.count(op))
505c4f1fddSRiver Riddle       return;
515c4f1fddSRiver Riddle 
5264d52014SChris Lattner     worklistMap[op] = worklist.size();
5364d52014SChris Lattner     worklist.push_back(op);
5464d52014SChris Lattner   }
5564d52014SChris Lattner 
565187cfcfSChris Lattner   OperationInst *popFromWorklist() {
5764d52014SChris Lattner     auto *op = worklist.back();
5864d52014SChris Lattner     worklist.pop_back();
5964d52014SChris Lattner 
6064d52014SChris Lattner     // This operation is no longer in the worklist, keep worklistMap up to date.
6164d52014SChris Lattner     if (op)
6264d52014SChris Lattner       worklistMap.erase(op);
6364d52014SChris Lattner     return op;
6464d52014SChris Lattner   }
6564d52014SChris Lattner 
6664d52014SChris Lattner   /// If the specified operation is in the worklist, remove it.  If not, this is
6764d52014SChris Lattner   /// a no-op.
685187cfcfSChris Lattner   void removeFromWorklist(OperationInst *op) {
6964d52014SChris Lattner     auto it = worklistMap.find(op);
7064d52014SChris Lattner     if (it != worklistMap.end()) {
7164d52014SChris Lattner       assert(worklist[it->second] == op && "malformed worklist data structure");
7264d52014SChris Lattner       worklist[it->second] = nullptr;
7364d52014SChris Lattner     }
7464d52014SChris Lattner   }
7564d52014SChris Lattner 
764bd9f936SChris Lattner   // These are hooks implemented for PatternRewriter.
774bd9f936SChris Lattner protected:
784bd9f936SChris Lattner   // Implement the hook for creating operations, and make sure that newly
794bd9f936SChris Lattner   // created ops are added to the worklist for processing.
804bd9f936SChris Lattner   OperationInst *createOperation(const OperationState &state) override {
814bd9f936SChris Lattner     auto *result = builder.createOperation(state);
824bd9f936SChris Lattner     addToWorklist(result);
834bd9f936SChris Lattner     return result;
844bd9f936SChris Lattner   }
8564d52014SChris Lattner 
8664d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
8764d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
885187cfcfSChris Lattner   void notifyOperationRemoved(OperationInst *op) override {
894bd9f936SChris Lattner     removeFromWorklist(op);
9064d52014SChris Lattner   }
9164d52014SChris Lattner 
92085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
93085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
94085b687fSChris Lattner   // before the root is changed.
955187cfcfSChris Lattner   void notifyRootReplaced(OperationInst *op) override {
96085b687fSChris Lattner     for (auto *result : op->getResults())
97085b687fSChris Lattner       // TODO: Add a result->getUsers() iterator.
98085b687fSChris Lattner       for (auto &user : result->getUses()) {
995187cfcfSChris Lattner         if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
1004bd9f936SChris Lattner           addToWorklist(op);
101085b687fSChris Lattner       }
102085b687fSChris Lattner 
103085b687fSChris Lattner     // TODO: Walk the operand list dropping them as we go.  If any of them
104085b687fSChris Lattner     // drop to zero uses, then add them to the worklist to allow them to be
105085b687fSChris Lattner     // deleted as dead.
106085b687fSChris Lattner   }
107085b687fSChris Lattner 
1084bd9f936SChris Lattner private:
1094bd9f936SChris Lattner   /// The low-level pattern matcher.
1104bd9f936SChris Lattner   PatternMatcher matcher;
1114bd9f936SChris Lattner 
1124bd9f936SChris Lattner   /// This builder is used to create new operations.
1134bd9f936SChris Lattner   FuncBuilder builder;
1144bd9f936SChris Lattner 
1154bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
1164bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
1174bd9f936SChris Lattner   /// efficiently remove operations from the worklist when they are erased from
1184bd9f936SChris Lattner   /// the function, even if they aren't the root of a pattern.
1194bd9f936SChris Lattner   std::vector<OperationInst *> worklist;
1204bd9f936SChris Lattner   DenseMap<OperationInst *, unsigned> worklistMap;
1214bd9f936SChris Lattner 
1224bd9f936SChris Lattner   /// As part of canonicalization, we move constants to the top of the entry
1234bd9f936SChris Lattner   /// block of the current function and de-duplicate them.  This keeps track of
1244bd9f936SChris Lattner   /// constants we have done this for.
1254bd9f936SChris Lattner   DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
12664d52014SChris Lattner };
1274bd9f936SChris Lattner }; // end anonymous namespace
12864d52014SChris Lattner 
1294bd9f936SChris Lattner /// Perform the rewrites.
1304bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() {
13164d52014SChris Lattner   // These are scratch vectors used in the constant folding loop below.
132792d1c25SRiver Riddle   SmallVector<Attribute, 8> operandConstants, resultConstants;
13364d52014SChris Lattner 
13464d52014SChris Lattner   while (!worklist.empty()) {
13564d52014SChris Lattner     auto *op = popFromWorklist();
13664d52014SChris Lattner 
13764d52014SChris Lattner     // Nulls get added to the worklist when operations are removed, ignore them.
13864d52014SChris Lattner     if (op == nullptr)
13964d52014SChris Lattner       continue;
14064d52014SChris Lattner 
14164d52014SChris Lattner     // If we have a constant op, unique it into the entry block.
14264d52014SChris Lattner     if (auto constant = op->dyn_cast<ConstantOp>()) {
14364d52014SChris Lattner       // If this constant is dead, remove it, being careful to keep
14464d52014SChris Lattner       // uniquedConstants up to date.
14564d52014SChris Lattner       if (constant->use_empty()) {
14664d52014SChris Lattner         auto it =
14764d52014SChris Lattner             uniquedConstants.find({constant->getValue(), constant->getType()});
14864d52014SChris Lattner         if (it != uniquedConstants.end() && it->second == op)
14964d52014SChris Lattner           uniquedConstants.erase(it);
15064d52014SChris Lattner         constant->erase();
15164d52014SChris Lattner         continue;
15264d52014SChris Lattner       }
15364d52014SChris Lattner 
15464d52014SChris Lattner       // Check to see if we already have a constant with this type and value:
15564d52014SChris Lattner       auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
15664d52014SChris Lattner                                                     constant->getType())];
15764d52014SChris Lattner       if (entry) {
15864d52014SChris Lattner         // If this constant is already our uniqued one, then leave it alone.
15964d52014SChris Lattner         if (entry == op)
16064d52014SChris Lattner           continue;
16164d52014SChris Lattner 
16264d52014SChris Lattner         // Otherwise replace this redundant constant with the uniqued one.  We
16364d52014SChris Lattner         // know this is safe because we move constants to the top of the
16464d52014SChris Lattner         // function when they are uniqued, so we know they dominate all uses.
16564d52014SChris Lattner         constant->replaceAllUsesWith(entry->getResult(0));
16664d52014SChris Lattner         constant->erase();
16764d52014SChris Lattner         continue;
16864d52014SChris Lattner       }
16964d52014SChris Lattner 
17064d52014SChris Lattner       // If we have no entry, then we should unique this constant as the
17164d52014SChris Lattner       // canonical version.  To ensure safe dominance, move the operation to the
17264d52014SChris Lattner       // top of the function.
17364d52014SChris Lattner       entry = op;
1744bd9f936SChris Lattner       auto &entryBB = builder.getInsertionBlock()->getFunction()->front();
175471c9764SChris Lattner       op->moveBefore(&entryBB, entryBB.begin());
17664d52014SChris Lattner       continue;
17764d52014SChris Lattner     }
17864d52014SChris Lattner 
17964d52014SChris Lattner     // If the operation has no side effects, and no users, then it is trivially
18064d52014SChris Lattner     // dead - remove it.
18164d52014SChris Lattner     if (op->hasNoSideEffect() && op->use_empty()) {
18264d52014SChris Lattner       op->erase();
18364d52014SChris Lattner       continue;
18464d52014SChris Lattner     }
18564d52014SChris Lattner 
18664d52014SChris Lattner     // Check to see if any operands to the instruction is constant and whether
18764d52014SChris Lattner     // the operation knows how to constant fold itself.
18864d52014SChris Lattner     operandConstants.clear();
18964d52014SChris Lattner     for (auto *operand : op->getOperands()) {
190792d1c25SRiver Riddle       Attribute operandCst;
1915187cfcfSChris Lattner       if (auto *operandOp = operand->getDefiningInst()) {
19264d52014SChris Lattner         if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
19364d52014SChris Lattner           operandCst = operandConstantOp->getValue();
19464d52014SChris Lattner       }
19564d52014SChris Lattner       operandConstants.push_back(operandCst);
19664d52014SChris Lattner     }
19764d52014SChris Lattner 
19864d52014SChris Lattner     // If constant folding was successful, create the result constants, RAUW the
19964d52014SChris Lattner     // operation and remove it.
20064d52014SChris Lattner     resultConstants.clear();
20164d52014SChris Lattner     if (!op->constantFold(operandConstants, resultConstants)) {
2024bd9f936SChris Lattner       builder.setInsertionPoint(op);
20364d52014SChris Lattner 
20464d52014SChris Lattner       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
20564d52014SChris Lattner         auto *res = op->getResult(i);
20664d52014SChris Lattner         if (res->use_empty()) // ignore dead uses.
20764d52014SChris Lattner           continue;
20864d52014SChris Lattner 
20964d52014SChris Lattner         // If we already have a canonicalized version of this constant, just
21064d52014SChris Lattner         // reuse it.  Otherwise create a new one.
2113f190312SChris Lattner         Value *cstValue;
21264d52014SChris Lattner         auto it = uniquedConstants.find({resultConstants[i], res->getType()});
21364d52014SChris Lattner         if (it != uniquedConstants.end())
21464d52014SChris Lattner           cstValue = it->second->getResult(0);
21564d52014SChris Lattner         else
216*61ec6c09SLei Zhang           cstValue = create<ConstantOp>(op->getLoc(), res->getType(),
217*61ec6c09SLei Zhang                                         resultConstants[i]);
218967d9341SChris Lattner 
219967d9341SChris Lattner         // Add all the users of the result to the worklist so we make sure to
220967d9341SChris Lattner         // revisit them.
221967d9341SChris Lattner         //
222967d9341SChris Lattner         // TODO: Add a result->getUsers() iterator.
223085b687fSChris Lattner         for (auto &operand : op->getResult(i)->getUses()) {
2245187cfcfSChris Lattner           if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
225967d9341SChris Lattner             addToWorklist(op);
226967d9341SChris Lattner         }
227967d9341SChris Lattner 
22864d52014SChris Lattner         res->replaceAllUsesWith(cstValue);
22964d52014SChris Lattner       }
23064d52014SChris Lattner 
23164d52014SChris Lattner       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
23264d52014SChris Lattner       op->erase();
23364d52014SChris Lattner       continue;
23464d52014SChris Lattner     }
23564d52014SChris Lattner 
236967d9341SChris Lattner     // If this is a commutative binary operation with a constant on the left
237967d9341SChris Lattner     // side move it to the right side.
23864d52014SChris Lattner     if (operandConstants.size() == 2 && operandConstants[0] &&
239967d9341SChris Lattner         !operandConstants[1] && op->isCommutative()) {
24064d52014SChris Lattner       auto *newLHS = op->getOperand(1);
24164d52014SChris Lattner       op->setOperand(1, op->getOperand(0));
24264d52014SChris Lattner       op->setOperand(0, newLHS);
24364d52014SChris Lattner     }
24464d52014SChris Lattner 
24564d52014SChris Lattner     // Check to see if we have any patterns that match this node.
24664d52014SChris Lattner     auto match = matcher.findMatch(op);
24764d52014SChris Lattner     if (!match.first)
24864d52014SChris Lattner       continue;
24964d52014SChris Lattner 
25064d52014SChris Lattner     // Make sure that any new operations are inserted at this point.
2514bd9f936SChris Lattner     builder.setInsertionPoint(op);
2523f2530cdSChris Lattner     // We know that any pattern that matched is RewritePattern because we
2533f2530cdSChris Lattner     // initialized the matcher with RewritePatterns.
2543f2530cdSChris Lattner     auto *rewritePattern = static_cast<RewritePattern *>(match.first);
2554bd9f936SChris Lattner     rewritePattern->rewrite(op, std::move(match.second), *this);
25664d52014SChris Lattner   }
25764d52014SChris Lattner 
25864d52014SChris Lattner   uniquedConstants.clear();
25964d52014SChris Lattner }
26064d52014SChris Lattner 
26164d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit
26264d52014SChris Lattner /// patterns in a greedy work-list driven manner.
26364d52014SChris Lattner ///
2643f2530cdSChris Lattner void mlir::applyPatternsGreedily(Function *fn,
2653f2530cdSChris Lattner                                  OwningRewritePatternList &&patterns) {
2664bd9f936SChris Lattner   GreedyPatternRewriteDriver driver(fn, std::move(patterns));
2674bd9f936SChris Lattner   driver.simplifyFunction();
26864d52014SChris Lattner }
269