164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
364d52014SChris Lattner // Copyright 2019 The MLIR Authors.
464d52014SChris Lattner //
564d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License");
664d52014SChris Lattner // you may not use this file except in compliance with the License.
764d52014SChris Lattner // You may obtain a copy of the License at
864d52014SChris Lattner //
964d52014SChris Lattner //   http://www.apache.org/licenses/LICENSE-2.0
1064d52014SChris Lattner //
1164d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software
1264d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS,
1364d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1464d52014SChris Lattner // See the License for the specific language governing permissions and
1564d52014SChris Lattner // limitations under the License.
1664d52014SChris Lattner // =============================================================================
1764d52014SChris Lattner //
1864d52014SChris Lattner // This file implements mlir::applyPatternsGreedily.
1964d52014SChris Lattner //
2064d52014SChris Lattner //===----------------------------------------------------------------------===//
2164d52014SChris Lattner 
2264d52014SChris Lattner #include "mlir/IR/Builders.h"
237de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
24f37651c7SRiver Riddle #include "mlir/StandardOps/Ops.h"
25*4e40c832SLei Zhang #include "mlir/Transforms/ConstantFoldUtils.h"
2664d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
27*4e40c832SLei Zhang 
2864d52014SChris Lattner using namespace mlir;
2964d52014SChris Lattner 
3064d52014SChris Lattner namespace {
3164d52014SChris Lattner 
3264d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3364d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
344bd9f936SChris Lattner class GreedyPatternRewriteDriver : public PatternRewriter {
3564d52014SChris Lattner public:
3646ade282SChris Lattner   explicit GreedyPatternRewriteDriver(Function &fn,
374bd9f936SChris Lattner                                       OwningRewritePatternList &&patterns)
3846ade282SChris Lattner       : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this),
3946ade282SChris Lattner         builder(&fn) {
4064d52014SChris Lattner     worklist.reserve(64);
414bd9f936SChris Lattner 
424bd9f936SChris Lattner     // Add all operations to the worklist.
4399b87c97SRiver Riddle     fn.walk([&](Operation *op) { addToWorklist(op); });
4464d52014SChris Lattner   }
4564d52014SChris Lattner 
464bd9f936SChris Lattner   /// Perform the rewrites.
474bd9f936SChris Lattner   void simplifyFunction();
4864d52014SChris Lattner 
4999b87c97SRiver Riddle   void addToWorklist(Operation *op) {
505c4f1fddSRiver Riddle     // Check to see if the worklist already contains this op.
515c4f1fddSRiver Riddle     if (worklistMap.count(op))
525c4f1fddSRiver Riddle       return;
535c4f1fddSRiver Riddle 
5464d52014SChris Lattner     worklistMap[op] = worklist.size();
5564d52014SChris Lattner     worklist.push_back(op);
5664d52014SChris Lattner   }
5764d52014SChris Lattner 
5899b87c97SRiver Riddle   Operation *popFromWorklist() {
5964d52014SChris Lattner     auto *op = worklist.back();
6064d52014SChris Lattner     worklist.pop_back();
6164d52014SChris Lattner 
6264d52014SChris Lattner     // This operation is no longer in the worklist, keep worklistMap up to date.
6364d52014SChris Lattner     if (op)
6464d52014SChris Lattner       worklistMap.erase(op);
6564d52014SChris Lattner     return op;
6664d52014SChris Lattner   }
6764d52014SChris Lattner 
6864d52014SChris Lattner   /// If the specified operation is in the worklist, remove it.  If not, this is
6964d52014SChris Lattner   /// a no-op.
7099b87c97SRiver Riddle   void removeFromWorklist(Operation *op) {
7164d52014SChris Lattner     auto it = worklistMap.find(op);
7264d52014SChris Lattner     if (it != worklistMap.end()) {
7364d52014SChris Lattner       assert(worklist[it->second] == op && "malformed worklist data structure");
7464d52014SChris Lattner       worklist[it->second] = nullptr;
7564d52014SChris Lattner     }
7664d52014SChris Lattner   }
7764d52014SChris Lattner 
784bd9f936SChris Lattner   // These are hooks implemented for PatternRewriter.
794bd9f936SChris Lattner protected:
804bd9f936SChris Lattner   // Implement the hook for creating operations, and make sure that newly
814bd9f936SChris Lattner   // created ops are added to the worklist for processing.
8299b87c97SRiver Riddle   Operation *createOperation(const OperationState &state) override {
834bd9f936SChris Lattner     auto *result = builder.createOperation(state);
844bd9f936SChris Lattner     addToWorklist(result);
854bd9f936SChris Lattner     return result;
864bd9f936SChris Lattner   }
8764d52014SChris Lattner 
8864d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
8964d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
9099b87c97SRiver Riddle   void notifyOperationRemoved(Operation *op) override {
91a8866258SRiver Riddle     addToWorklist(op->getOperands());
924bd9f936SChris Lattner     removeFromWorklist(op);
9364d52014SChris Lattner   }
9464d52014SChris Lattner 
95085b687fSChris Lattner   // When the root of a pattern is about to be replaced, it can trigger
96085b687fSChris Lattner   // simplifications to its users - make sure to add them to the worklist
97085b687fSChris Lattner   // before the root is changed.
9899b87c97SRiver Riddle   void notifyRootReplaced(Operation *op) override {
99085b687fSChris Lattner     for (auto *result : op->getResults())
100085b687fSChris Lattner       // TODO: Add a result->getUsers() iterator.
101b499277fSRiver Riddle       for (auto &user : result->getUses())
102b499277fSRiver Riddle         addToWorklist(user.getOwner());
103085b687fSChris Lattner   }
104085b687fSChris Lattner 
1054bd9f936SChris Lattner private:
10699b87c97SRiver Riddle   // Look over the provided operands for any defining operations that should
107a8866258SRiver Riddle   // be re-added to the worklist. This function should be called when an
108a8866258SRiver Riddle   // operation is modified or removed, as it may trigger further
109a8866258SRiver Riddle   // simplifications.
110a8866258SRiver Riddle   template <typename Operands> void addToWorklist(Operands &&operands) {
111a8866258SRiver Riddle     for (Value *operand : operands) {
112a8866258SRiver Riddle       // If the use count of this operand is now < 2, we re-add the defining
11399b87c97SRiver Riddle       // operation to the worklist.
11499b87c97SRiver Riddle       // TODO(riverriddle) This is based on the fact that zero use operations
115a8866258SRiver Riddle       // may be deleted, and that single use values often have more
116a8866258SRiver Riddle       // canonicalization opportunities.
117a8866258SRiver Riddle       if (!operand->use_empty() &&
118a8866258SRiver Riddle           std::next(operand->use_begin()) != operand->use_end())
119a8866258SRiver Riddle         continue;
120f9d91531SRiver Riddle       if (auto *defInst = operand->getDefiningOp())
121a8866258SRiver Riddle         addToWorklist(defInst);
122a8866258SRiver Riddle     }
123a8866258SRiver Riddle   }
124a8866258SRiver Riddle 
1254bd9f936SChris Lattner   /// The low-level pattern matcher.
1265de726f4SRiver Riddle   RewritePatternMatcher matcher;
1274bd9f936SChris Lattner 
1284bd9f936SChris Lattner   /// This builder is used to create new operations.
1294bd9f936SChris Lattner   FuncBuilder builder;
1304bd9f936SChris Lattner 
1314bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
1324bd9f936SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
1334bd9f936SChris Lattner   /// efficiently remove operations from the worklist when they are erased from
1344bd9f936SChris Lattner   /// the function, even if they aren't the root of a pattern.
13599b87c97SRiver Riddle   std::vector<Operation *> worklist;
13699b87c97SRiver Riddle   DenseMap<Operation *, unsigned> worklistMap;
13764d52014SChris Lattner };
1384bd9f936SChris Lattner }; // end anonymous namespace
13964d52014SChris Lattner 
1404bd9f936SChris Lattner /// Perform the rewrites.
1414bd9f936SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction() {
142*4e40c832SLei Zhang   ConstantFoldHelper helper(builder.getFunction());
143*4e40c832SLei Zhang 
144*4e40c832SLei Zhang   // These are scratch vectors used in the folding loop below.
145934b6d12SChris Lattner   SmallVector<Value *, 8> originalOperands, resultValues;
14664d52014SChris Lattner 
14764d52014SChris Lattner   while (!worklist.empty()) {
14864d52014SChris Lattner     auto *op = popFromWorklist();
14964d52014SChris Lattner 
15064d52014SChris Lattner     // Nulls get added to the worklist when operations are removed, ignore them.
15164d52014SChris Lattner     if (op == nullptr)
15264d52014SChris Lattner       continue;
15364d52014SChris Lattner 
15464d52014SChris Lattner     // If the operation has no side effects, and no users, then it is trivially
15564d52014SChris Lattner     // dead - remove it.
15664d52014SChris Lattner     if (op->hasNoSideEffect() && op->use_empty()) {
157*4e40c832SLei Zhang       // Be careful to update bookkeeping in ConstantHelper to keep consistency
158*4e40c832SLei Zhang       // if this is a constant op.
159*4e40c832SLei Zhang       if (op->isa<ConstantOp>())
160*4e40c832SLei Zhang         helper.notifyRemoval(op);
16164d52014SChris Lattner       op->erase();
16264d52014SChris Lattner       continue;
16364d52014SChris Lattner     }
16464d52014SChris Lattner 
165*4e40c832SLei Zhang     // Collects all the operands and result uses of the given `op` into work
166*4e40c832SLei Zhang     // list.
167*4e40c832SLei Zhang     auto collectOperandsAndUses = [this](Operation *op) {
168a8866258SRiver Riddle       // Add the operands to the worklist for visitation.
169a8866258SRiver Riddle       addToWorklist(op->getOperands());
170*4e40c832SLei Zhang       // Add all the users of the result to the worklist so we make sure
171*4e40c832SLei Zhang       // to revisit them.
172967d9341SChris Lattner       //
173967d9341SChris Lattner       // TODO: Add a result->getUsers() iterator.
174*4e40c832SLei Zhang       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
175b499277fSRiver Riddle         for (auto &operand : op->getResult(i)->getUses())
176b499277fSRiver Riddle           addToWorklist(operand.getOwner());
17764d52014SChris Lattner       }
178*4e40c832SLei Zhang     };
17964d52014SChris Lattner 
180*4e40c832SLei Zhang     // Try to constant fold this op.
181*4e40c832SLei Zhang     if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
18264d52014SChris Lattner       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
18364d52014SChris Lattner       op->erase();
18464d52014SChris Lattner       continue;
18564d52014SChris Lattner     }
18664d52014SChris Lattner 
187934b6d12SChris Lattner     // Otherwise see if we can use the generic folder API to simplify the
188934b6d12SChris Lattner     // operation.
189934b6d12SChris Lattner     originalOperands.assign(op->operand_begin(), op->operand_end());
190934b6d12SChris Lattner     resultValues.clear();
1915e1f1d2cSRiver Riddle     if (succeeded(op->fold(resultValues))) {
192934b6d12SChris Lattner       // If the result was an in-place simplification (e.g. max(x,x,y) ->
193934b6d12SChris Lattner       // max(x,y)) then add the original operands to the worklist so we can make
194934b6d12SChris Lattner       // sure to revisit them.
195934b6d12SChris Lattner       if (resultValues.empty()) {
196a8866258SRiver Riddle         // Add the operands back to the worklist as there may be more
197a8866258SRiver Riddle         // canonicalization opportunities now.
198a8866258SRiver Riddle         addToWorklist(originalOperands);
199934b6d12SChris Lattner       } else {
200934b6d12SChris Lattner         // Otherwise, the operation is simplified away completely.
201934b6d12SChris Lattner         assert(resultValues.size() == op->getNumResults());
202934b6d12SChris Lattner 
203a8866258SRiver Riddle         // Notify that we are replacing this operation.
204a8866258SRiver Riddle         notifyRootReplaced(op);
205a8866258SRiver Riddle 
206a8866258SRiver Riddle         // Replace the result values and erase the operation.
207934b6d12SChris Lattner         for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
208934b6d12SChris Lattner           auto *res = op->getResult(i);
209a8866258SRiver Riddle           if (!res->use_empty())
210934b6d12SChris Lattner             res->replaceAllUsesWith(resultValues[i]);
211934b6d12SChris Lattner         }
212934b6d12SChris Lattner 
213a8866258SRiver Riddle         notifyOperationRemoved(op);
214934b6d12SChris Lattner         op->erase();
21599fee0b1SRiver Riddle       }
216934b6d12SChris Lattner       continue;
21764d52014SChris Lattner     }
21864d52014SChris Lattner 
21964d52014SChris Lattner     // Make sure that any new operations are inserted at this point.
2204bd9f936SChris Lattner     builder.setInsertionPoint(op);
2215de726f4SRiver Riddle 
2225de726f4SRiver Riddle     // Try to match one of the canonicalization patterns. The rewriter is
2235de726f4SRiver Riddle     // automatically notified of any necessary changes, so there is nothing else
2245de726f4SRiver Riddle     // to do here.
2255de726f4SRiver Riddle     matcher.matchAndRewrite(op);
22664d52014SChris Lattner   }
22764d52014SChris Lattner }
22864d52014SChris Lattner 
22964d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit
23064d52014SChris Lattner /// patterns in a greedy work-list driven manner.
23164d52014SChris Lattner ///
23246ade282SChris Lattner void mlir::applyPatternsGreedily(Function &fn,
2333f2530cdSChris Lattner                                  OwningRewritePatternList &&patterns) {
2344bd9f936SChris Lattner   GreedyPatternRewriteDriver driver(fn, std::move(patterns));
2354bd9f936SChris Lattner   driver.simplifyFunction();
23664d52014SChris Lattner }
237