1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/RegionGraphTraits.h"
13 #include "mlir/IR/Value.h"
14 
15 #include "llvm/ADT/DepthFirstIterator.h"
16 #include "llvm/ADT/PostOrderIterator.h"
17 #include "llvm/ADT/SmallSet.h"
18 
19 using namespace mlir;
20 
21 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
22                                       Region &region) {
23   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
24     if (region.isAncestor(use.getOwner()->getParentRegion()))
25       use.set(replacement);
26   }
27 }
28 
29 void mlir::visitUsedValuesDefinedAbove(
30     Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
31   assert(limit.isAncestor(&region) &&
32          "expected isolation limit to be an ancestor of the given region");
33 
34   // Collect proper ancestors of `limit` upfront to avoid traversing the region
35   // tree for every value.
36   SmallPtrSet<Region *, 4> properAncestors;
37   for (auto *reg = limit.getParentRegion(); reg != nullptr;
38        reg = reg->getParentRegion()) {
39     properAncestors.insert(reg);
40   }
41 
42   region.walk([callback, &properAncestors](Operation *op) {
43     for (OpOperand &operand : op->getOpOperands())
44       // Callback on values defined in a proper ancestor of region.
45       if (properAncestors.count(operand.get().getParentRegion()))
46         callback(&operand);
47   });
48 }
49 
50 void mlir::visitUsedValuesDefinedAbove(
51     MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
52   for (Region &region : regions)
53     visitUsedValuesDefinedAbove(region, region, callback);
54 }
55 
56 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
57                                      llvm::SetVector<Value> &values) {
58   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
59     values.insert(operand->get());
60   });
61 }
62 
63 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
64                                      llvm::SetVector<Value> &values) {
65   for (Region &region : regions)
66     getUsedValuesDefinedAbove(region, region, values);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Unreachable Block Elimination
71 //===----------------------------------------------------------------------===//
72 
73 /// Erase the unreachable blocks within the provided regions. Returns success
74 /// if any blocks were erased, failure otherwise.
75 // TODO: We could likely merge this with the DCE algorithm below.
76 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
77   // Set of blocks found to be reachable within a given region.
78   llvm::df_iterator_default_set<Block *, 16> reachable;
79   // If any blocks were found to be dead.
80   bool erasedDeadBlocks = false;
81 
82   SmallVector<Region *, 1> worklist;
83   worklist.reserve(regions.size());
84   for (Region &region : regions)
85     worklist.push_back(&region);
86   while (!worklist.empty()) {
87     Region *region = worklist.pop_back_val();
88     if (region->empty())
89       continue;
90 
91     // If this is a single block region, just collect the nested regions.
92     if (std::next(region->begin()) == region->end()) {
93       for (Operation &op : region->front())
94         for (Region &region : op.getRegions())
95           worklist.push_back(&region);
96       continue;
97     }
98 
99     // Mark all reachable blocks.
100     reachable.clear();
101     for (Block *block : depth_first_ext(&region->front(), reachable))
102       (void)block /* Mark all reachable blocks */;
103 
104     // Collect all of the dead blocks and push the live regions onto the
105     // worklist.
106     for (Block &block : llvm::make_early_inc_range(*region)) {
107       if (!reachable.count(&block)) {
108         block.dropAllDefinedValueUses();
109         block.erase();
110         erasedDeadBlocks = true;
111         continue;
112       }
113 
114       // Walk any regions within this block.
115       for (Operation &op : block)
116         for (Region &region : op.getRegions())
117           worklist.push_back(&region);
118     }
119   }
120 
121   return success(erasedDeadBlocks);
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // Dead Code Elimination
126 //===----------------------------------------------------------------------===//
127 
128 namespace {
129 /// Data structure used to track which values have already been proved live.
130 ///
131 /// Because Operation's can have multiple results, this data structure tracks
132 /// liveness for both Value's and Operation's to avoid having to look through
133 /// all Operation results when analyzing a use.
134 ///
135 /// This data structure essentially tracks the dataflow lattice.
136 /// The set of values/ops proved live increases monotonically to a fixed-point.
137 class LiveMap {
138 public:
139   /// Value methods.
140   bool wasProvenLive(Value value) { return liveValues.count(value); }
141   void setProvedLive(Value value) {
142     changed |= liveValues.insert(value).second;
143   }
144 
145   /// Operation methods.
146   bool wasProvenLive(Operation *op) { return liveOps.count(op); }
147   void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
148 
149   /// Methods for tracking if we have reached a fixed-point.
150   void resetChanged() { changed = false; }
151   bool hasChanged() { return changed; }
152 
153 private:
154   bool changed = false;
155   DenseSet<Value> liveValues;
156   DenseSet<Operation *> liveOps;
157 };
158 } // namespace
159 
160 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
161   Operation *owner = use.getOwner();
162   unsigned operandIndex = use.getOperandNumber();
163   // This pass generally treats all uses of an op as live if the op itself is
164   // considered live. However, for successor operands to terminators we need a
165   // finer-grained notion where we deduce liveness for operands individually.
166   // The reason for this is easiest to think about in terms of a classical phi
167   // node based SSA IR, where each successor operand is really an operand to a
168   // *separate* phi node, rather than all operands to the branch itself as with
169   // the block argument representation that MLIR uses.
170   //
171   // And similarly, because each successor operand is really an operand to a phi
172   // node, rather than to the terminator op itself, a terminator op can't e.g.
173   // "print" the value of a successor operand.
174   if (owner->isKnownTerminator()) {
175     if (auto arg = owner->getSuccessorBlockArgument(operandIndex))
176       return !liveMap.wasProvenLive(*arg);
177     return false;
178   }
179   return false;
180 }
181 
182 static void processValue(Value value, LiveMap &liveMap) {
183   bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
184     if (isUseSpeciallyKnownDead(use, liveMap))
185       return false;
186     return liveMap.wasProvenLive(use.getOwner());
187   });
188   if (provedLive)
189     liveMap.setProvedLive(value);
190 }
191 
192 static bool isOpIntrinsicallyLive(Operation *op) {
193   // This pass doesn't modify the CFG, so terminators are never deleted.
194   if (!op->isKnownNonTerminator())
195     return true;
196   // If the op has a side effect, we treat it as live.
197   if (!op->hasNoSideEffect())
198     return true;
199   return false;
200 }
201 
202 static void propagateLiveness(Region &region, LiveMap &liveMap);
203 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
204   // All Value's are either a block argument or an op result.
205   // We call processValue on those cases.
206 
207   // Recurse on any regions the op has.
208   for (Region &region : op->getRegions())
209     propagateLiveness(region, liveMap);
210 
211   // Process the op itself.
212   if (isOpIntrinsicallyLive(op)) {
213     liveMap.setProvedLive(op);
214     return;
215   }
216   for (Value value : op->getResults())
217     processValue(value, liveMap);
218   bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
219     return liveMap.wasProvenLive(value);
220   });
221   if (provedLive)
222     liveMap.setProvedLive(op);
223 }
224 
225 static void propagateLiveness(Region &region, LiveMap &liveMap) {
226   if (region.empty())
227     return;
228 
229   for (Block *block : llvm::post_order(&region.front())) {
230     // We process block arguments after the ops in the block, to promote
231     // faster convergence to a fixed point (we try to visit uses before defs).
232     for (Operation &op : llvm::reverse(block->getOperations()))
233       propagateLiveness(&op, liveMap);
234     for (Value value : block->getArguments())
235       processValue(value, liveMap);
236   }
237 }
238 
239 static void eraseTerminatorSuccessorOperands(Operation *terminator,
240                                              LiveMap &liveMap) {
241   for (unsigned succI = 0, succE = terminator->getNumSuccessors();
242        succI < succE; succI++) {
243     // Iterating successors in reverse is not strictly needed, since we
244     // aren't erasing any successors. But it is slightly more efficient
245     // since it will promote later operands of the terminator being erased
246     // first, reducing the quadratic-ness.
247     unsigned succ = succE - succI - 1;
248     for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ);
249          argI < argE; argI++) {
250       // Iterating args in reverse is needed for correctness, to avoid
251       // shifting later args when earlier args are erased.
252       unsigned arg = argE - argI - 1;
253       Value value = terminator->getSuccessor(succ)->getArgument(arg);
254       if (!liveMap.wasProvenLive(value)) {
255         terminator->eraseSuccessorOperand(succ, arg);
256       }
257     }
258   }
259 }
260 
261 static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
262                                     LiveMap &liveMap) {
263   bool erasedAnything = false;
264   for (Region &region : regions) {
265     if (region.empty())
266       continue;
267 
268     // We do the deletion in an order that deletes all uses before deleting
269     // defs.
270     // MLIR's SSA structural invariants guarantee that except for block
271     // arguments, the use-def graph is acyclic, so this is possible with a
272     // single walk of ops and then a final pass to clean up block arguments.
273     //
274     // To do this, we visit ops in an order that visits domtree children
275     // before domtree parents. A CFG post-order (with reverse iteration with a
276     // block) satisfies that without needing an explicit domtree calculation.
277     for (Block *block : llvm::post_order(&region.front())) {
278       eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
279       for (Operation &childOp :
280            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
281         erasedAnything |=
282             succeeded(deleteDeadness(childOp.getRegions(), liveMap));
283         if (!liveMap.wasProvenLive(&childOp)) {
284           erasedAnything = true;
285           childOp.erase();
286         }
287       }
288     }
289     // Delete block arguments.
290     // The entry block has an unknown contract with their enclosing block, so
291     // skip it.
292     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
293       // Iterate in reverse to avoid shifting later arguments when deleting
294       // earlier arguments.
295       for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
296         if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
297           block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
298           erasedAnything = true;
299         }
300     }
301   }
302   return success(erasedAnything);
303 }
304 
305 // This function performs a simple dead code elimination algorithm over the
306 // given regions.
307 //
308 // The overall goal is to prove that Values are dead, which allows deleting ops
309 // and block arguments.
310 //
311 // This uses an optimistic algorithm that assumes everything is dead until
312 // proved otherwise, allowing it to delete recursively dead cycles.
313 //
314 // This is a simple fixed-point dataflow analysis algorithm on a lattice
315 // {Dead,Alive}. Because liveness flows backward, we generally try to
316 // iterate everything backward to speed up convergence to the fixed-point. This
317 // allows for being able to delete recursively dead cycles of the use-def graph,
318 // including block arguments.
319 //
320 // This function returns success if any operations or arguments were deleted,
321 // failure otherwise.
322 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
323   LiveMap liveMap;
324   do {
325     liveMap.resetChanged();
326 
327     for (Region &region : regions)
328       propagateLiveness(region, liveMap);
329   } while (liveMap.hasChanged());
330 
331   return deleteDeadness(regions, liveMap);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // Region Simplification
336 //===----------------------------------------------------------------------===//
337 
338 /// Run a set of structural simplifications over the given regions. This
339 /// includes transformations like unreachable block elimination, dead argument
340 /// elimination, as well as some other DCE. This function returns success if any
341 /// of the regions were simplified, failure otherwise.
342 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
343   LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
344   LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
345   return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
346 }
347