1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/RegionUtils.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/RegionGraphTraits.h"
22 #include "mlir/IR/Value.h"
23 
24 #include "llvm/ADT/DepthFirstIterator.h"
25 #include "llvm/ADT/PostOrderIterator.h"
26 #include "llvm/ADT/SmallSet.h"
27 
28 using namespace mlir;
29 
30 void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
31                                       Region &region) {
32   for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) {
33     if (region.isAncestor(use.getOwner()->getParentRegion()))
34       use.set(replacement);
35   }
36 }
37 
38 void mlir::visitUsedValuesDefinedAbove(
39     Region &region, Region &limit,
40     llvm::function_ref<void(OpOperand *)> callback) {
41   assert(limit.isAncestor(&region) &&
42          "expected isolation limit to be an ancestor of the given region");
43 
44   // Collect proper ancestors of `limit` upfront to avoid traversing the region
45   // tree for every value.
46   llvm::SmallPtrSet<Region *, 4> properAncestors;
47   for (auto *reg = limit.getParentRegion(); reg != nullptr;
48        reg = reg->getParentRegion()) {
49     properAncestors.insert(reg);
50   }
51 
52   region.walk([callback, &properAncestors](Operation *op) {
53     for (OpOperand &operand : op->getOpOperands())
54       // Callback on values defined in a proper ancestor of region.
55       if (properAncestors.count(operand.get()->getParentRegion()))
56         callback(&operand);
57   });
58 }
59 
60 void mlir::visitUsedValuesDefinedAbove(
61     llvm::MutableArrayRef<Region> regions,
62     llvm::function_ref<void(OpOperand *)> callback) {
63   for (Region &region : regions)
64     visitUsedValuesDefinedAbove(region, region, callback);
65 }
66 
67 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
68                                      llvm::SetVector<Value *> &values) {
69   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
70     values.insert(operand->get());
71   });
72 }
73 
74 void mlir::getUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions,
75                                      llvm::SetVector<Value *> &values) {
76   for (Region &region : regions)
77     getUsedValuesDefinedAbove(region, region, values);
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Unreachable Block Elimination
82 //===----------------------------------------------------------------------===//
83 
84 /// Erase the unreachable blocks within the provided regions. Returns success
85 /// if any blocks were erased, failure otherwise.
86 // TODO: We could likely merge this with the DCE algorithm below.
87 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
88   // Set of blocks found to be reachable within a given region.
89   llvm::df_iterator_default_set<Block *, 16> reachable;
90   // If any blocks were found to be dead.
91   bool erasedDeadBlocks = false;
92 
93   SmallVector<Region *, 1> worklist;
94   worklist.reserve(regions.size());
95   for (Region &region : regions)
96     worklist.push_back(&region);
97   while (!worklist.empty()) {
98     Region *region = worklist.pop_back_val();
99     if (region->empty())
100       continue;
101 
102     // If this is a single block region, just collect the nested regions.
103     if (std::next(region->begin()) == region->end()) {
104       for (Operation &op : region->front())
105         for (Region &region : op.getRegions())
106           worklist.push_back(&region);
107       continue;
108     }
109 
110     // Mark all reachable blocks.
111     reachable.clear();
112     for (Block *block : depth_first_ext(&region->front(), reachable))
113       (void)block /* Mark all reachable blocks */;
114 
115     // Collect all of the dead blocks and push the live regions onto the
116     // worklist.
117     for (Block &block : llvm::make_early_inc_range(*region)) {
118       if (!reachable.count(&block)) {
119         block.dropAllDefinedValueUses();
120         block.erase();
121         erasedDeadBlocks = true;
122         continue;
123       }
124 
125       // Walk any regions within this block.
126       for (Operation &op : block)
127         for (Region &region : op.getRegions())
128           worklist.push_back(&region);
129     }
130   }
131 
132   return success(erasedDeadBlocks);
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // Dead Code Elimination
137 //===----------------------------------------------------------------------===//
138 
139 namespace {
140 /// Data structure used to track which values have already been proved live.
141 ///
142 /// Because Operation's can have multiple results, this data structure tracks
143 /// liveness for both Value's and Operation's to avoid having to look through
144 /// all Operation results when analyzing a use.
145 ///
146 /// This data structure essentially tracks the dataflow lattice.
147 /// The set of values/ops proved live increases monotonically to a fixed-point.
148 class LiveMap {
149 public:
150   /// Value methods.
151   bool wasProvenLive(Value *value) { return liveValues.count(value); }
152   void setProvedLive(Value *value) {
153     changed |= liveValues.insert(value).second;
154   }
155 
156   /// Operation methods.
157   bool wasProvenLive(Operation *op) { return liveOps.count(op); }
158   void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
159 
160   /// Methods for tracking if we have reached a fixed-point.
161   void resetChanged() { changed = false; }
162   bool hasChanged() { return changed; }
163 
164 private:
165   bool changed = false;
166   DenseSet<Value *> liveValues;
167   DenseSet<Operation *> liveOps;
168 };
169 } // namespace
170 
171 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
172   Operation *owner = use.getOwner();
173   unsigned operandIndex = use.getOperandNumber();
174   // This pass generally treats all uses of an op as live if the op itself is
175   // considered live. However, for successor operands to terminators we need a
176   // finer-grained notion where we deduce liveness for operands individually.
177   // The reason for this is easiest to think about in terms of a classical phi
178   // node based SSA IR, where each successor operand is really an operand to a
179   // *separate* phi node, rather than all operands to the branch itself as with
180   // the block argument representation that MLIR uses.
181   //
182   // And similarly, because each successor operand is really an operand to a phi
183   // node, rather than to the terminator op itself, a terminator op can't e.g.
184   // "print" the value of a successor operand.
185   if (owner->isKnownTerminator()) {
186     if (auto arg = owner->getSuccessorBlockArgument(operandIndex))
187       return !liveMap.wasProvenLive(*arg);
188     return false;
189   }
190   return false;
191 }
192 
193 static void processValue(Value *value, LiveMap &liveMap) {
194   bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) {
195     if (isUseSpeciallyKnownDead(use, liveMap))
196       return false;
197     return liveMap.wasProvenLive(use.getOwner());
198   });
199   if (provedLive)
200     liveMap.setProvedLive(value);
201 }
202 
203 static bool isOpIntrinsicallyLive(Operation *op) {
204   // This pass doesn't modify the CFG, so terminators are never deleted.
205   if (!op->isKnownNonTerminator())
206     return true;
207   // If the op has a side effect, we treat it as live.
208   if (!op->hasNoSideEffect())
209     return true;
210   return false;
211 }
212 
213 static void propagateLiveness(Region &region, LiveMap &liveMap);
214 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
215   // All Value's are either a block argument or an op result.
216   // We call processValue on those cases.
217 
218   // Recurse on any regions the op has.
219   for (Region &region : op->getRegions())
220     propagateLiveness(region, liveMap);
221 
222   // Process the op itself.
223   if (isOpIntrinsicallyLive(op)) {
224     liveMap.setProvedLive(op);
225     return;
226   }
227   for (Value *value : op->getResults())
228     processValue(value, liveMap);
229   bool provedLive = llvm::any_of(op->getResults(), [&](Value *value) {
230     return liveMap.wasProvenLive(value);
231   });
232   if (provedLive)
233     liveMap.setProvedLive(op);
234 }
235 
236 static void propagateLiveness(Region &region, LiveMap &liveMap) {
237   if (region.empty())
238     return;
239 
240   for (Block *block : llvm::post_order(&region.front())) {
241     // We process block arguments after the ops in the block, to promote
242     // faster convergence to a fixed point (we try to visit uses before defs).
243     for (Operation &op : llvm::reverse(block->getOperations()))
244       propagateLiveness(&op, liveMap);
245     for (Value *value : block->getArguments())
246       processValue(value, liveMap);
247   }
248 }
249 
250 static void eraseTerminatorSuccessorOperands(Operation *terminator,
251                                              LiveMap &liveMap) {
252   for (unsigned succI = 0, succE = terminator->getNumSuccessors();
253        succI < succE; succI++) {
254     // Iterating successors in reverse is not strictly needed, since we
255     // aren't erasing any successors. But it is slightly more efficient
256     // since it will promote later operands of the terminator being erased
257     // first, reducing the quadratic-ness.
258     unsigned succ = succE - succI - 1;
259     for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ);
260          argI < argE; argI++) {
261       // Iterating args in reverse is needed for correctness, to avoid
262       // shifting later args when earlier args are erased.
263       unsigned arg = argE - argI - 1;
264       Value *value = terminator->getSuccessor(succ)->getArgument(arg);
265       if (!liveMap.wasProvenLive(value)) {
266         terminator->eraseSuccessorOperand(succ, arg);
267       }
268     }
269   }
270 }
271 
272 static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
273                                     LiveMap &liveMap) {
274   bool erasedAnything = false;
275   for (Region &region : regions) {
276     if (region.empty())
277       continue;
278 
279     // We do the deletion in an order that deletes all uses before deleting
280     // defs.
281     // MLIR's SSA structural invariants guarantee that except for block
282     // arguments, the use-def graph is acyclic, so this is possible with a
283     // single walk of ops and then a final pass to clean up block arguments.
284     //
285     // To do this, we visit ops in an order that visits domtree children
286     // before domtree parents. A CFG post-order (with reverse iteration with a
287     // block) satisfies that without needing an explicit domtree calculation.
288     for (Block *block : llvm::post_order(&region.front())) {
289       eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
290       for (Operation &childOp :
291            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
292         erasedAnything |=
293             succeeded(deleteDeadness(childOp.getRegions(), liveMap));
294         if (!liveMap.wasProvenLive(&childOp)) {
295           erasedAnything = true;
296           childOp.erase();
297         }
298       }
299     }
300     // Delete block arguments.
301     // The entry block has an unknown contract with their enclosing block, so
302     // skip it.
303     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
304       // Iterate in reverse to avoid shifting later arguments when deleting
305       // earlier arguments.
306       for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
307         if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
308           block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
309           erasedAnything = true;
310         }
311     }
312   }
313   return success(erasedAnything);
314 }
315 
316 // This function performs a simple dead code elimination algorithm over the
317 // given regions.
318 //
319 // The overall goal is to prove that Values are dead, which allows deleting ops
320 // and block arguments.
321 //
322 // This uses an optimistic algorithm that assumes everything is dead until
323 // proved otherwise, allowing it to delete recursively dead cycles.
324 //
325 // This is a simple fixed-point dataflow analysis algorithm on a lattice
326 // {Dead,Alive}. Because liveness flows backward, we generally try to
327 // iterate everything backward to speed up convergence to the fixed-point. This
328 // allows for being able to delete recursively dead cycles of the use-def graph,
329 // including block arguments.
330 //
331 // This function returns success if any operations or arguments were deleted,
332 // failure otherwise.
333 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
334   assert(regions.size() == 1);
335 
336   LiveMap liveMap;
337   do {
338     liveMap.resetChanged();
339 
340     for (Region &region : regions)
341       propagateLiveness(region, liveMap);
342   } while (liveMap.hasChanged());
343 
344   return deleteDeadness(regions, liveMap);
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // Region Simplification
349 //===----------------------------------------------------------------------===//
350 
351 /// Run a set of structural simplifications over the given regions. This
352 /// includes transformations like unreachable block elimination, dead argument
353 /// elimination, as well as some other DCE. This function returns success if any
354 /// of the regions were simplified, failure otherwise.
355 LogicalResult mlir::simplifyRegions(llvm::MutableArrayRef<Region> regions) {
356   LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
357   LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
358   return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
359 }
360