1*64d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2*64d52014SChris Lattner //
3*64d52014SChris Lattner // Copyright 2019 The MLIR Authors.
4*64d52014SChris Lattner //
5*64d52014SChris Lattner // Licensed under the Apache License, Version 2.0 (the "License");
6*64d52014SChris Lattner // you may not use this file except in compliance with the License.
7*64d52014SChris Lattner // You may obtain a copy of the License at
8*64d52014SChris Lattner //
9*64d52014SChris Lattner //   http://www.apache.org/licenses/LICENSE-2.0
10*64d52014SChris Lattner //
11*64d52014SChris Lattner // Unless required by applicable law or agreed to in writing, software
12*64d52014SChris Lattner // distributed under the License is distributed on an "AS IS" BASIS,
13*64d52014SChris Lattner // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14*64d52014SChris Lattner // See the License for the specific language governing permissions and
15*64d52014SChris Lattner // limitations under the License.
16*64d52014SChris Lattner // =============================================================================
17*64d52014SChris Lattner //
18*64d52014SChris Lattner // This file implements mlir::applyPatternsGreedily.
19*64d52014SChris Lattner //
20*64d52014SChris Lattner //===----------------------------------------------------------------------===//
21*64d52014SChris Lattner 
22*64d52014SChris Lattner #include "mlir/IR/Builders.h"
23*64d52014SChris Lattner #include "mlir/IR/BuiltinOps.h"
24*64d52014SChris Lattner #include "mlir/StandardOps/StandardOps.h"
25*64d52014SChris Lattner #include "mlir/Transforms/PatternMatch.h"
26*64d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
27*64d52014SChris Lattner using namespace mlir;
28*64d52014SChris Lattner 
29*64d52014SChris Lattner namespace {
30*64d52014SChris Lattner class WorklistRewriter;
31*64d52014SChris Lattner 
32*64d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
33*64d52014SChris Lattner /// applies the locally optimal patterns in a roughly "bottom up" way.
34*64d52014SChris Lattner class GreedyPatternRewriteDriver {
35*64d52014SChris Lattner public:
36*64d52014SChris Lattner   explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
37*64d52014SChris Lattner       : matcher(std::move(patterns)) {
38*64d52014SChris Lattner     worklist.reserve(64);
39*64d52014SChris Lattner   }
40*64d52014SChris Lattner 
41*64d52014SChris Lattner   void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
42*64d52014SChris Lattner 
43*64d52014SChris Lattner   void addToWorklist(Operation *op) {
44*64d52014SChris Lattner     worklistMap[op] = worklist.size();
45*64d52014SChris Lattner     worklist.push_back(op);
46*64d52014SChris Lattner   }
47*64d52014SChris Lattner 
48*64d52014SChris Lattner   Operation *popFromWorklist() {
49*64d52014SChris Lattner     auto *op = worklist.back();
50*64d52014SChris Lattner     worklist.pop_back();
51*64d52014SChris Lattner 
52*64d52014SChris Lattner     // This operation is no longer in the worklist, keep worklistMap up to date.
53*64d52014SChris Lattner     if (op)
54*64d52014SChris Lattner       worklistMap.erase(op);
55*64d52014SChris Lattner     return op;
56*64d52014SChris Lattner   }
57*64d52014SChris Lattner 
58*64d52014SChris Lattner   /// If the specified operation is in the worklist, remove it.  If not, this is
59*64d52014SChris Lattner   /// a no-op.
60*64d52014SChris Lattner   void removeFromWorklist(Operation *op) {
61*64d52014SChris Lattner     auto it = worklistMap.find(op);
62*64d52014SChris Lattner     if (it != worklistMap.end()) {
63*64d52014SChris Lattner       assert(worklist[it->second] == op && "malformed worklist data structure");
64*64d52014SChris Lattner       worklist[it->second] = nullptr;
65*64d52014SChris Lattner     }
66*64d52014SChris Lattner   }
67*64d52014SChris Lattner 
68*64d52014SChris Lattner private:
69*64d52014SChris Lattner   /// The low-level pattern matcher.
70*64d52014SChris Lattner   PatternMatcher matcher;
71*64d52014SChris Lattner 
72*64d52014SChris Lattner   /// The worklist for this transformation keeps track of the operations that
73*64d52014SChris Lattner   /// need to be revisited, plus their index in the worklist.  This allows us to
74*64d52014SChris Lattner   /// efficiently remove operations from the worklist when they are removed even
75*64d52014SChris Lattner   /// if they aren't the root of a pattern.
76*64d52014SChris Lattner   std::vector<Operation *> worklist;
77*64d52014SChris Lattner   DenseMap<Operation *, unsigned> worklistMap;
78*64d52014SChris Lattner 
79*64d52014SChris Lattner   /// As part of canonicalization, we move constants to the top of the entry
80*64d52014SChris Lattner   /// block of the current function and de-duplicate them.  This keeps track of
81*64d52014SChris Lattner   /// constants we have done this for.
82*64d52014SChris Lattner   DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
83*64d52014SChris Lattner };
84*64d52014SChris Lattner }; // end anonymous namespace
85*64d52014SChris Lattner 
86*64d52014SChris Lattner /// This is a listener object that updates our worklists and other data
87*64d52014SChris Lattner /// structures in response to operations being added and removed.
88*64d52014SChris Lattner namespace {
89*64d52014SChris Lattner class WorklistRewriter : public PatternRewriter {
90*64d52014SChris Lattner public:
91*64d52014SChris Lattner   WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
92*64d52014SChris Lattner       : PatternRewriter(context), driver(driver) {}
93*64d52014SChris Lattner 
94*64d52014SChris Lattner   virtual void setInsertionPoint(Operation *op) = 0;
95*64d52014SChris Lattner 
96*64d52014SChris Lattner   // If an operation is about to be removed, make sure it is not in our
97*64d52014SChris Lattner   // worklist anymore because we'd get dangling references to it.
98*64d52014SChris Lattner   void notifyOperationRemoved(Operation *op) override {
99*64d52014SChris Lattner     driver.removeFromWorklist(op);
100*64d52014SChris Lattner   }
101*64d52014SChris Lattner 
102*64d52014SChris Lattner   GreedyPatternRewriteDriver &driver;
103*64d52014SChris Lattner };
104*64d52014SChris Lattner 
105*64d52014SChris Lattner } // end anonymous namespace
106*64d52014SChris Lattner 
107*64d52014SChris Lattner void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
108*64d52014SChris Lattner                                                   WorklistRewriter &rewriter) {
109*64d52014SChris Lattner   // These are scratch vectors used in the constant folding loop below.
110*64d52014SChris Lattner   SmallVector<Attribute *, 8> operandConstants, resultConstants;
111*64d52014SChris Lattner 
112*64d52014SChris Lattner   while (!worklist.empty()) {
113*64d52014SChris Lattner     auto *op = popFromWorklist();
114*64d52014SChris Lattner 
115*64d52014SChris Lattner     // Nulls get added to the worklist when operations are removed, ignore them.
116*64d52014SChris Lattner     if (op == nullptr)
117*64d52014SChris Lattner       continue;
118*64d52014SChris Lattner 
119*64d52014SChris Lattner     // If we have a constant op, unique it into the entry block.
120*64d52014SChris Lattner     if (auto constant = op->dyn_cast<ConstantOp>()) {
121*64d52014SChris Lattner       // If this constant is dead, remove it, being careful to keep
122*64d52014SChris Lattner       // uniquedConstants up to date.
123*64d52014SChris Lattner       if (constant->use_empty()) {
124*64d52014SChris Lattner         auto it =
125*64d52014SChris Lattner             uniquedConstants.find({constant->getValue(), constant->getType()});
126*64d52014SChris Lattner         if (it != uniquedConstants.end() && it->second == op)
127*64d52014SChris Lattner           uniquedConstants.erase(it);
128*64d52014SChris Lattner         constant->erase();
129*64d52014SChris Lattner         continue;
130*64d52014SChris Lattner       }
131*64d52014SChris Lattner 
132*64d52014SChris Lattner       // Check to see if we already have a constant with this type and value:
133*64d52014SChris Lattner       auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
134*64d52014SChris Lattner                                                     constant->getType())];
135*64d52014SChris Lattner       if (entry) {
136*64d52014SChris Lattner         // If this constant is already our uniqued one, then leave it alone.
137*64d52014SChris Lattner         if (entry == op)
138*64d52014SChris Lattner           continue;
139*64d52014SChris Lattner 
140*64d52014SChris Lattner         // Otherwise replace this redundant constant with the uniqued one.  We
141*64d52014SChris Lattner         // know this is safe because we move constants to the top of the
142*64d52014SChris Lattner         // function when they are uniqued, so we know they dominate all uses.
143*64d52014SChris Lattner         constant->replaceAllUsesWith(entry->getResult(0));
144*64d52014SChris Lattner         constant->erase();
145*64d52014SChris Lattner         continue;
146*64d52014SChris Lattner       }
147*64d52014SChris Lattner 
148*64d52014SChris Lattner       // If we have no entry, then we should unique this constant as the
149*64d52014SChris Lattner       // canonical version.  To ensure safe dominance, move the operation to the
150*64d52014SChris Lattner       // top of the function.
151*64d52014SChris Lattner       entry = op;
152*64d52014SChris Lattner 
153*64d52014SChris Lattner       // TODO: If we make terminators into Operations then we could turn this
154*64d52014SChris Lattner       // into a nice Operation::moveBefore(Operation*) method.  We just need the
155*64d52014SChris Lattner       // guarantee that a block is non-empty.
156*64d52014SChris Lattner       if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
157*64d52014SChris Lattner         auto &entryBB = cfgFunc->front();
158*64d52014SChris Lattner         cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
159*64d52014SChris Lattner       } else {
160*64d52014SChris Lattner         auto *mlFunc = cast<MLFunction>(currentFunction);
161*64d52014SChris Lattner         cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
162*64d52014SChris Lattner       }
163*64d52014SChris Lattner 
164*64d52014SChris Lattner       continue;
165*64d52014SChris Lattner     }
166*64d52014SChris Lattner 
167*64d52014SChris Lattner     // If the operation has no side effects, and no users, then it is trivially
168*64d52014SChris Lattner     // dead - remove it.
169*64d52014SChris Lattner     if (op->hasNoSideEffect() && op->use_empty()) {
170*64d52014SChris Lattner       op->erase();
171*64d52014SChris Lattner       continue;
172*64d52014SChris Lattner     }
173*64d52014SChris Lattner 
174*64d52014SChris Lattner     // Check to see if any operands to the instruction is constant and whether
175*64d52014SChris Lattner     // the operation knows how to constant fold itself.
176*64d52014SChris Lattner     operandConstants.clear();
177*64d52014SChris Lattner     for (auto *operand : op->getOperands()) {
178*64d52014SChris Lattner       Attribute *operandCst = nullptr;
179*64d52014SChris Lattner       if (auto *operandOp = operand->getDefiningOperation()) {
180*64d52014SChris Lattner         if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
181*64d52014SChris Lattner           operandCst = operandConstantOp->getValue();
182*64d52014SChris Lattner       }
183*64d52014SChris Lattner       operandConstants.push_back(operandCst);
184*64d52014SChris Lattner     }
185*64d52014SChris Lattner 
186*64d52014SChris Lattner     // If constant folding was successful, create the result constants, RAUW the
187*64d52014SChris Lattner     // operation and remove it.
188*64d52014SChris Lattner     resultConstants.clear();
189*64d52014SChris Lattner     if (!op->constantFold(operandConstants, resultConstants)) {
190*64d52014SChris Lattner       rewriter.setInsertionPoint(op);
191*64d52014SChris Lattner 
192*64d52014SChris Lattner       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
193*64d52014SChris Lattner         auto *res = op->getResult(i);
194*64d52014SChris Lattner         if (res->use_empty()) // ignore dead uses.
195*64d52014SChris Lattner           continue;
196*64d52014SChris Lattner 
197*64d52014SChris Lattner         // If we already have a canonicalized version of this constant, just
198*64d52014SChris Lattner         // reuse it.  Otherwise create a new one.
199*64d52014SChris Lattner         SSAValue *cstValue;
200*64d52014SChris Lattner         auto it = uniquedConstants.find({resultConstants[i], res->getType()});
201*64d52014SChris Lattner         if (it != uniquedConstants.end())
202*64d52014SChris Lattner           cstValue = it->second->getResult(0);
203*64d52014SChris Lattner         else
204*64d52014SChris Lattner           cstValue = rewriter.create<ConstantOp>(
205*64d52014SChris Lattner               op->getLoc(), resultConstants[i], res->getType());
206*64d52014SChris Lattner         res->replaceAllUsesWith(cstValue);
207*64d52014SChris Lattner       }
208*64d52014SChris Lattner 
209*64d52014SChris Lattner       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
210*64d52014SChris Lattner       op->erase();
211*64d52014SChris Lattner       continue;
212*64d52014SChris Lattner     }
213*64d52014SChris Lattner 
214*64d52014SChris Lattner     // If this is an associative binary operation with a constant on the LHS,
215*64d52014SChris Lattner     // move it to the right side.
216*64d52014SChris Lattner     if (operandConstants.size() == 2 && operandConstants[0] &&
217*64d52014SChris Lattner         !operandConstants[1]) {
218*64d52014SChris Lattner       auto *newLHS = op->getOperand(1);
219*64d52014SChris Lattner       op->setOperand(1, op->getOperand(0));
220*64d52014SChris Lattner       op->setOperand(0, newLHS);
221*64d52014SChris Lattner     }
222*64d52014SChris Lattner 
223*64d52014SChris Lattner     // Check to see if we have any patterns that match this node.
224*64d52014SChris Lattner     auto match = matcher.findMatch(op);
225*64d52014SChris Lattner     if (!match.first)
226*64d52014SChris Lattner       continue;
227*64d52014SChris Lattner 
228*64d52014SChris Lattner     // Make sure that any new operations are inserted at this point.
229*64d52014SChris Lattner     rewriter.setInsertionPoint(op);
230*64d52014SChris Lattner     match.first->rewrite(op, std::move(match.second), rewriter);
231*64d52014SChris Lattner   }
232*64d52014SChris Lattner 
233*64d52014SChris Lattner   uniquedConstants.clear();
234*64d52014SChris Lattner }
235*64d52014SChris Lattner 
236*64d52014SChris Lattner static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
237*64d52014SChris Lattner   class MLFuncRewriter : public WorklistRewriter {
238*64d52014SChris Lattner   public:
239*64d52014SChris Lattner     MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
240*64d52014SChris Lattner         : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
241*64d52014SChris Lattner 
242*64d52014SChris Lattner     // Implement the hook for creating operations, and make sure that newly
243*64d52014SChris Lattner     // created ops are added to the worklist for processing.
244*64d52014SChris Lattner     Operation *createOperation(const OperationState &state) override {
245*64d52014SChris Lattner       auto *result = builder.createOperation(state);
246*64d52014SChris Lattner       driver.addToWorklist(result);
247*64d52014SChris Lattner       return result;
248*64d52014SChris Lattner     }
249*64d52014SChris Lattner 
250*64d52014SChris Lattner     // When the root of a pattern is about to be replaced, it can trigger
251*64d52014SChris Lattner     // simplifications to its users - make sure to add them to the worklist
252*64d52014SChris Lattner     // before the root is changed.
253*64d52014SChris Lattner     void notifyRootReplaced(Operation *op) override {
254*64d52014SChris Lattner       auto *opStmt = cast<OperationStmt>(op);
255*64d52014SChris Lattner       for (auto *result : opStmt->getResults())
256*64d52014SChris Lattner         // TODO: Add a result->getUsers() iterator.
257*64d52014SChris Lattner         for (auto &user : result->getUses()) {
258*64d52014SChris Lattner           if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
259*64d52014SChris Lattner             driver.addToWorklist(op);
260*64d52014SChris Lattner         }
261*64d52014SChris Lattner 
262*64d52014SChris Lattner       // TODO: Walk the operand list dropping them as we go.  If any of them
263*64d52014SChris Lattner       // drop to zero uses, then add them to the worklist to allow them to be
264*64d52014SChris Lattner       // deleted as dead.
265*64d52014SChris Lattner     }
266*64d52014SChris Lattner 
267*64d52014SChris Lattner     void setInsertionPoint(Operation *op) override {
268*64d52014SChris Lattner       // Any new operations should be added before this statement.
269*64d52014SChris Lattner       builder.setInsertionPoint(cast<OperationStmt>(op));
270*64d52014SChris Lattner     }
271*64d52014SChris Lattner 
272*64d52014SChris Lattner   private:
273*64d52014SChris Lattner     MLFuncBuilder &builder;
274*64d52014SChris Lattner   };
275*64d52014SChris Lattner 
276*64d52014SChris Lattner   GreedyPatternRewriteDriver driver(std::move(patterns));
277*64d52014SChris Lattner   fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
278*64d52014SChris Lattner 
279*64d52014SChris Lattner   MLFuncBuilder mlBuilder(fn);
280*64d52014SChris Lattner   MLFuncRewriter rewriter(driver, mlBuilder);
281*64d52014SChris Lattner   driver.simplifyFunction(fn, rewriter);
282*64d52014SChris Lattner }
283*64d52014SChris Lattner 
284*64d52014SChris Lattner static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
285*64d52014SChris Lattner   class CFGFuncRewriter : public WorklistRewriter {
286*64d52014SChris Lattner   public:
287*64d52014SChris Lattner     CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
288*64d52014SChris Lattner         : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
289*64d52014SChris Lattner 
290*64d52014SChris Lattner     // Implement the hook for creating operations, and make sure that newly
291*64d52014SChris Lattner     // created ops are added to the worklist for processing.
292*64d52014SChris Lattner     Operation *createOperation(const OperationState &state) override {
293*64d52014SChris Lattner       auto *result = builder.createOperation(state);
294*64d52014SChris Lattner       driver.addToWorklist(result);
295*64d52014SChris Lattner       return result;
296*64d52014SChris Lattner     }
297*64d52014SChris Lattner 
298*64d52014SChris Lattner     // When the root of a pattern is about to be replaced, it can trigger
299*64d52014SChris Lattner     // simplifications to its users - make sure to add them to the worklist
300*64d52014SChris Lattner     // before the root is changed.
301*64d52014SChris Lattner     void notifyRootReplaced(Operation *op) override {
302*64d52014SChris Lattner       auto *opStmt = cast<OperationInst>(op);
303*64d52014SChris Lattner       for (auto *result : opStmt->getResults())
304*64d52014SChris Lattner         // TODO: Add a result->getUsers() iterator.
305*64d52014SChris Lattner         for (auto &user : result->getUses()) {
306*64d52014SChris Lattner           if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
307*64d52014SChris Lattner             driver.addToWorklist(op);
308*64d52014SChris Lattner         }
309*64d52014SChris Lattner 
310*64d52014SChris Lattner       // TODO: Walk the operand list dropping them as we go.  If any of them
311*64d52014SChris Lattner       // drop to zero uses, then add them to the worklist to allow them to be
312*64d52014SChris Lattner       // deleted as dead.
313*64d52014SChris Lattner     }
314*64d52014SChris Lattner 
315*64d52014SChris Lattner     void setInsertionPoint(Operation *op) override {
316*64d52014SChris Lattner       // Any new operations should be added before this instruction.
317*64d52014SChris Lattner       builder.setInsertionPoint(cast<OperationInst>(op));
318*64d52014SChris Lattner     }
319*64d52014SChris Lattner 
320*64d52014SChris Lattner   private:
321*64d52014SChris Lattner     CFGFuncBuilder &builder;
322*64d52014SChris Lattner   };
323*64d52014SChris Lattner 
324*64d52014SChris Lattner   GreedyPatternRewriteDriver driver(std::move(patterns));
325*64d52014SChris Lattner   for (auto &bb : *fn)
326*64d52014SChris Lattner     for (auto &op : bb)
327*64d52014SChris Lattner       driver.addToWorklist(&op);
328*64d52014SChris Lattner 
329*64d52014SChris Lattner   CFGFuncBuilder cfgBuilder(fn);
330*64d52014SChris Lattner   CFGFuncRewriter rewriter(driver, cfgBuilder);
331*64d52014SChris Lattner   driver.simplifyFunction(fn, rewriter);
332*64d52014SChris Lattner }
333*64d52014SChris Lattner 
334*64d52014SChris Lattner /// Rewrite the specified function by repeatedly applying the highest benefit
335*64d52014SChris Lattner /// patterns in a greedy work-list driven manner.
336*64d52014SChris Lattner ///
337*64d52014SChris Lattner void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
338*64d52014SChris Lattner   if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
339*64d52014SChris Lattner     processCFGFunction(cfg, std::move(patterns));
340*64d52014SChris Lattner   } else {
341*64d52014SChris Lattner     processMLFunction(cast<MLFunction>(fn), std::move(patterns));
342*64d52014SChris Lattner   }
343*64d52014SChris Lattner }
344