164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
364d52014SChris Lattner // Copyright 2019 The MLIR Authors.
464d52014SChris Lattner //
564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License");
664d52014SChris Lattner // you may not use this file except in compliance with the License.
764d52014SChris Lattner // You may obtain a copy of the License at
864d52014SChris Lattner //
964d52014SChris Lattner //   http://www.apache.org/licenses/LICENSE-2.0
1064d52014SChris Lattner //
1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software
1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS,
1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1464d52014SChris Lattner // See the License for the specific language governing permissions and
1564d52014SChris Lattner // limitations under the License.
1664d52014SChris Lattner // =============================================================================
1764d52014SChris Lattner //
1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily.
1964d52014SChris Lattner //
2064d52014SChris Lattner //===----------------------------------------------------------------------===//
2164d52014SChris Lattner 
2264d52014SChris Lattner #include "mlir/IR/Builders.h"
23da0ebe06SRiver Riddle #include "mlir/IR/Matchers.h"
247de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
25f37651c7SRiver Riddle #include "mlir/StandardOps/Ops.h"
2664d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
2764d52014SChris Lattner using namespace mlir;
2864d52014SChris Lattner 
2964d52014SChris Lattner namespace {
3064d52014SChris Lattner 
3164d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3264d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
334bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3464d52014SChris Lattner public:
35*46ade282SChris Lattner   explicit GreedyPatternRewriteDriver(Function &fn,
364bd9f936SChris Lattner                                       OwningRewritePatternList &&patterns)
37*46ade282SChris Lattner       : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this),
38*46ade282SChris Lattner         builder(&fn) {
3964d52014SChris Lattner     worklist.reserve(64);
404bd9f936SChris Lattner 
414bd9f936SChris Lattner     // Add all operations to the worklist.
42*46ade282SChris Lattner     fn.walk([&](Instruction *inst) { addToWorklist(inst); });
4364d52014SChris Lattner   }
4464d52014SChris Lattner 
454bd9f936SChris Lattner   /// Perform the rewrites.
464bd9f936SChris Lattner   void simplifyFunction();
4764d52014SChris Lattner 
48b499277fSRiver Riddle   void addToWorklist(Instruction *op) {
495c4f1fddSRiver Riddle     // Check to see if the worklist already contains this op.
505c4f1fddSRiver Riddle     if (worklistMap.count(op))
515c4f1fddSRiver Riddle       return;
525c4f1fddSRiver Riddle 
5364d52014SChris Lattner     worklistMap[op] = worklist.size();
5464d52014SChris Lattner     worklist.push_back(op);
5564d52014SChris Lattner   }
5664d52014SChris Lattner 
57b499277fSRiver Riddle   Instruction *popFromWorklist() {
5864d52014SChris Lattner     auto *op = worklist.back();
5964d52014SChris Lattner     worklist.pop_back();
6064d52014SChris Lattner 
6164d52014SChris Lattner     // This operation is no longer in the worklist, keep worklistMap up to date.
6264d52014SChris Lattner     if (op)
6364d52014SChris Lattner       worklistMap.erase(op);
6464d52014SChris Lattner     return op;
6564d52014SChris Lattner   }
6664d52014SChris Lattner 
6764d52014SChris Lattner   /// If the specified operation is in the worklist, remove it.  If not, this is
6864d52014SChris Lattner   /// a no-op.
69b499277fSRiver Riddle   void removeFromWorklist(Instruction *op) {
7064d52014SChris Lattner     auto it = worklistMap.find(op);
7164d52014SChris Lattner     if (it != worklistMap.end()) {
7264d52014SChris Lattner       assert(worklist[it->second] == op && "malformed worklist data structure");
7364d52014SChris Lattner       worklist[it->second] = nullptr;
7464d52014SChris Lattner     }
7564d52014SChris Lattner   }
7664d52014SChris Lattner 
774bd9f936SChris Lattner   // These are hooks implemented for PatternRewriter.
784bd9f936SChris Lattner protected:
794bd9f936SChris Lattner   // Implement the hook for creating operations, and make sure that newly
804bd9f936SChris Lattner   // created ops are added to the worklist for processing.
81b499277fSRiver Riddle   Instruction *createOperation(const OperationState &state) override {
824bd9f936SChris Lattner     auto *result = builder.createOperation(state);
834bd9f936SChris Lattner     addToWorklist(result);
844bd9f936SChris Lattner     return result;
854bd9f936SChris Lattner   }
8664d52014SChris Lattner 
8764d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
8864d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
89b499277fSRiver Riddle   void notifyOperationRemoved(Instruction *op) override {
90a8866258SRiver Riddle     addToWorklist(op->getOperands());
914bd9f936SChris Lattner     removeFromWorklist(op);
9264d52014SChris Lattner   }
9364d52014SChris Lattner 
94085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
95085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
96085b687fSChris Lattner   // before the root is changed.
97b499277fSRiver Riddle   void notifyRootReplaced(Instruction *op) override {
98085b687fSChris Lattner     for (auto *result : op->getResults())
99085b687fSChris Lattner       // TODO: Add a result->getUsers() iterator.
100b499277fSRiver Riddle       for (auto &user : result->getUses())
101b499277fSRiver Riddle         addToWorklist(user.getOwner());
102085b687fSChris Lattner   }
103085b687fSChris Lattner 
1044bd9f936SChris Lattner private:
105a8866258SRiver Riddle   // Look over the provided operands for any defining instructions that should
106a8866258SRiver Riddle   // be re-added to the worklist. This function should be called when an
107a8866258SRiver Riddle   // operation is modified or removed, as it may trigger further
108a8866258SRiver Riddle   // simplifications.
109a8866258SRiver Riddle   template <typename Operands> void addToWorklist(Operands &&operands) {
110a8866258SRiver Riddle     for (Value *operand : operands) {
111a8866258SRiver Riddle       // If the use count of this operand is now < 2, we re-add the defining
112a8866258SRiver Riddle       // instruction to the worklist.
113a8866258SRiver Riddle       // TODO(riverriddle) This is based on the fact that zero use instructions
114a8866258SRiver Riddle       // may be deleted, and that single use values often have more
115a8866258SRiver Riddle       // canonicalization opportunities.
116a8866258SRiver Riddle       if (!operand->use_empty() &&
117a8866258SRiver Riddle           std::next(operand->use_begin()) != operand->use_end())
118a8866258SRiver Riddle         continue;
119a8866258SRiver Riddle       if (auto *defInst = operand->getDefiningInst())
120a8866258SRiver Riddle         addToWorklist(defInst);
121a8866258SRiver Riddle     }
122a8866258SRiver Riddle   }
123a8866258SRiver Riddle 
1244bd9f936SChris Lattner   /// The low-level pattern matcher.
1255de726f4SRiver Riddle   RewritePatternMatcher matcher;
1264bd9f936SChris Lattner 
1274bd9f936SChris Lattner   /// This builder is used to create new operations.
1284bd9f936SChris Lattner   FuncBuilder builder;
1294bd9f936SChris Lattner 
1304bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
1314bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
1324bd9f936SChris Lattner   /// efficiently remove operations from the worklist when they are erased from
1334bd9f936SChris Lattner   /// the function, even if they aren't the root of a pattern.
134b499277fSRiver Riddle   std::vector<Instruction *> worklist;
135b499277fSRiver Riddle   DenseMap<Instruction *, unsigned> worklistMap;
1364bd9f936SChris Lattner 
1374bd9f936SChris Lattner   /// As part of canonicalization, we move constants to the top of the entry
1384bd9f936SChris Lattner   /// block of the current function and de-duplicate them.  This keeps track of
1394bd9f936SChris Lattner   /// constants we have done this for.
140b499277fSRiver Riddle   DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants;
14164d52014SChris Lattner };
1424bd9f936SChris Lattner }; // end anonymous namespace
14364d52014SChris Lattner 
1444bd9f936SChris Lattner /// Perform the rewrites.
1454bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() {
14664d52014SChris Lattner   // These are scratch vectors used in the constant folding loop below.
147792d1c25SRiver Riddle   SmallVector<Attribute, 8> operandConstants, resultConstants;
148934b6d12SChris Lattner   SmallVector<Value *, 8> originalOperands, resultValues;
14964d52014SChris Lattner 
15064d52014SChris Lattner   while (!worklist.empty()) {
15164d52014SChris Lattner     auto *op = popFromWorklist();
15264d52014SChris Lattner 
15364d52014SChris Lattner     // Nulls get added to the worklist when operations are removed, ignore them.
15464d52014SChris Lattner     if (op == nullptr)
15564d52014SChris Lattner       continue;
15664d52014SChris Lattner 
15764d52014SChris Lattner     // If we have a constant op, unique it into the entry block.
15864d52014SChris Lattner     if (auto constant = op->dyn_cast<ConstantOp>()) {
15964d52014SChris Lattner       // If this constant is dead, remove it, being careful to keep
16064d52014SChris Lattner       // uniquedConstants up to date.
16196ebde9cSRiver Riddle       if (constant.use_empty()) {
16264d52014SChris Lattner         auto it =
16396ebde9cSRiver Riddle             uniquedConstants.find({constant.getValue(), constant.getType()});
16464d52014SChris Lattner         if (it != uniquedConstants.end() && it->second == op)
16564d52014SChris Lattner           uniquedConstants.erase(it);
16696ebde9cSRiver Riddle         constant.erase();
16764d52014SChris Lattner         continue;
16864d52014SChris Lattner       }
16964d52014SChris Lattner 
17064d52014SChris Lattner       // Check to see if we already have a constant with this type and value:
17196ebde9cSRiver Riddle       auto &entry = uniquedConstants[std::make_pair(constant.getValue(),
17296ebde9cSRiver Riddle                                                     constant.getType())];
17364d52014SChris Lattner       if (entry) {
17464d52014SChris Lattner         // If this constant is already our uniqued one, then leave it alone.
17564d52014SChris Lattner         if (entry == op)
17664d52014SChris Lattner           continue;
17764d52014SChris Lattner 
17864d52014SChris Lattner         // Otherwise replace this redundant constant with the uniqued one.  We
17964d52014SChris Lattner         // know this is safe because we move constants to the top of the
18064d52014SChris Lattner         // function when they are uniqued, so we know they dominate all uses.
18196ebde9cSRiver Riddle         constant.replaceAllUsesWith(entry->getResult(0));
18296ebde9cSRiver Riddle         constant.erase();
18364d52014SChris Lattner         continue;
18464d52014SChris Lattner       }
18564d52014SChris Lattner 
18664d52014SChris Lattner       // If we have no entry, then we should unique this constant as the
18764d52014SChris Lattner       // canonical version.  To ensure safe dominance, move the operation to the
18864d52014SChris Lattner       // top of the function.
18964d52014SChris Lattner       entry = op;
1904bd9f936SChris Lattner       auto &entryBB = builder.getInsertionBlock()->getFunction()->front();
191471c9764SChris Lattner       op->moveBefore(&entryBB, entryBB.begin());
19264d52014SChris Lattner       continue;
19364d52014SChris Lattner     }
19464d52014SChris Lattner 
19564d52014SChris Lattner     // If the operation has no side effects, and no users, then it is trivially
19664d52014SChris Lattner     // dead - remove it.
19764d52014SChris Lattner     if (op->hasNoSideEffect() && op->use_empty()) {
19864d52014SChris Lattner       op->erase();
19964d52014SChris Lattner       continue;
20064d52014SChris Lattner     }
20164d52014SChris Lattner 
20264d52014SChris Lattner     // Check to see if any operands to the instruction is constant and whether
20364d52014SChris Lattner     // the operation knows how to constant fold itself.
204da0ebe06SRiver Riddle     operandConstants.assign(op->getNumOperands(), Attribute());
205da0ebe06SRiver Riddle     for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
206da0ebe06SRiver Riddle       matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
20764d52014SChris Lattner 
208934b6d12SChris Lattner     // If this is a commutative binary operation with a constant on the left
209934b6d12SChris Lattner     // side move it to the right side.
210934b6d12SChris Lattner     if (operandConstants.size() == 2 && operandConstants[0] &&
211934b6d12SChris Lattner         !operandConstants[1] && op->isCommutative()) {
212934b6d12SChris Lattner       std::swap(op->getInstOperand(0), op->getInstOperand(1));
213934b6d12SChris Lattner       std::swap(operandConstants[0], operandConstants[1]);
214934b6d12SChris Lattner     }
215934b6d12SChris Lattner 
21664d52014SChris Lattner     // If constant folding was successful, create the result constants, RAUW the
21764d52014SChris Lattner     // operation and remove it.
21864d52014SChris Lattner     resultConstants.clear();
2195e1f1d2cSRiver Riddle     if (succeeded(op->constantFold(operandConstants, resultConstants))) {
2204bd9f936SChris Lattner       builder.setInsertionPoint(op);
22164d52014SChris Lattner 
222a8866258SRiver Riddle       // Add the operands to the worklist for visitation.
223a8866258SRiver Riddle       addToWorklist(op->getOperands());
224a8866258SRiver Riddle 
22564d52014SChris Lattner       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
22664d52014SChris Lattner         auto *res = op->getResult(i);
22764d52014SChris Lattner         if (res->use_empty()) // ignore dead uses.
22864d52014SChris Lattner           continue;
22964d52014SChris Lattner 
23064d52014SChris Lattner         // If we already have a canonicalized version of this constant, just
23164d52014SChris Lattner         // reuse it.  Otherwise create a new one.
2323f190312SChris Lattner         Value *cstValue;
23364d52014SChris Lattner         auto it = uniquedConstants.find({resultConstants[i], res->getType()});
23464d52014SChris Lattner         if (it != uniquedConstants.end())
23564d52014SChris Lattner           cstValue = it->second->getResult(0);
23664d52014SChris Lattner         else
23761ec6c09SLei Zhang           cstValue = create<ConstantOp>(op->getLoc(), res->getType(),
23861ec6c09SLei Zhang                                         resultConstants[i]);
239967d9341SChris Lattner 
240967d9341SChris Lattner         // Add all the users of the result to the worklist so we make sure to
241967d9341SChris Lattner         // revisit them.
242967d9341SChris Lattner         //
243967d9341SChris Lattner         // TODO: Add a result->getUsers() iterator.
244b499277fSRiver Riddle         for (auto &operand : op->getResult(i)->getUses())
245b499277fSRiver Riddle           addToWorklist(operand.getOwner());
246967d9341SChris Lattner 
24764d52014SChris Lattner         res->replaceAllUsesWith(cstValue);
24864d52014SChris Lattner       }
24964d52014SChris Lattner 
25064d52014SChris Lattner       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
25164d52014SChris Lattner       op->erase();
25264d52014SChris Lattner       continue;
25364d52014SChris Lattner     }
25464d52014SChris Lattner 
255934b6d12SChris Lattner     // Otherwise see if we can use the generic folder API to simplify the
256934b6d12SChris Lattner     // operation.
257934b6d12SChris Lattner     originalOperands.assign(op->operand_begin(), op->operand_end());
258934b6d12SChris Lattner     resultValues.clear();
2595e1f1d2cSRiver Riddle     if (succeeded(op->fold(resultValues))) {
260934b6d12SChris Lattner       // If the result was an in-place simplification (e.g. max(x,x,y) ->
261934b6d12SChris Lattner       // max(x,y)) then add the original operands to the worklist so we can make
262934b6d12SChris Lattner       // sure to revisit them.
263934b6d12SChris Lattner       if (resultValues.empty()) {
264a8866258SRiver Riddle         // Add the operands back to the worklist as there may be more
265a8866258SRiver Riddle         // canonicalization opportunities now.
266a8866258SRiver Riddle         addToWorklist(originalOperands);
267934b6d12SChris Lattner       } else {
268934b6d12SChris Lattner         // Otherwise, the operation is simplified away completely.
269934b6d12SChris Lattner         assert(resultValues.size() == op->getNumResults());
270934b6d12SChris Lattner 
271a8866258SRiver Riddle         // Notify that we are replacing this operation.
272a8866258SRiver Riddle         notifyRootReplaced(op);
273a8866258SRiver Riddle 
274a8866258SRiver Riddle         // Replace the result values and erase the operation.
275934b6d12SChris Lattner         for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
276934b6d12SChris Lattner           auto *res = op->getResult(i);
277a8866258SRiver Riddle           if (!res->use_empty())
278934b6d12SChris Lattner             res->replaceAllUsesWith(resultValues[i]);
279934b6d12SChris Lattner         }
280934b6d12SChris Lattner 
281a8866258SRiver Riddle         notifyOperationRemoved(op);
282934b6d12SChris Lattner         op->erase();
28399fee0b1SRiver Riddle       }
284934b6d12SChris Lattner       continue;
28564d52014SChris Lattner     }
28664d52014SChris Lattner 
28764d52014SChris Lattner     // Make sure that any new operations are inserted at this point.
2884bd9f936SChris Lattner     builder.setInsertionPoint(op);
2895de726f4SRiver Riddle 
2905de726f4SRiver Riddle     // Try to match one of the canonicalization patterns. The rewriter is
2915de726f4SRiver Riddle     // automatically notified of any necessary changes, so there is nothing else
2925de726f4SRiver Riddle     // to do here.
2935de726f4SRiver Riddle     matcher.matchAndRewrite(op);
29464d52014SChris Lattner   }
29564d52014SChris Lattner 
29664d52014SChris Lattner   uniquedConstants.clear();
29764d52014SChris Lattner }
29864d52014SChris Lattner 
29964d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit
30064d52014SChris Lattner /// patterns in a greedy work-list driven manner.
30164d52014SChris Lattner ///
302*46ade282SChris Lattner void mlir::applyPatternsGreedily(Function &fn,
3033f2530cdSChris Lattner                                  OwningRewritePatternList &&patterns) {
3044bd9f936SChris Lattner   GreedyPatternRewriteDriver driver(fn, std::move(patterns));
3054bd9f936SChris Lattner   driver.simplifyFunction();
30664d52014SChris Lattner }
307