//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffects.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;

void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
                                      Region &region) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (region.isAncestor(use.getOwner()->getParentRegion()))
      use.set(replacement);
  }
}

void mlir::visitUsedValuesDefinedAbove(
    Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
  assert(limit.isAncestor(&region) &&
         "expected isolation limit to be an ancestor of the given region");

  // Collect proper ancestors of `limit` upfront to avoid traversing the region
  // tree for every value.
  SmallPtrSet<Region *, 4> properAncestors;
  for (auto *reg = limit.getParentRegion(); reg != nullptr;
       reg = reg->getParentRegion()) {
    properAncestors.insert(reg);
  }

  region.walk([callback, &properAncestors](Operation *op) {
    for (OpOperand &operand : op->getOpOperands())
      // Callback on values defined in a proper ancestor of region.
      if (properAncestors.count(operand.get().getParentRegion()))
        callback(&operand);
  });
}

void mlir::visitUsedValuesDefinedAbove(
    MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
  for (Region &region : regions)
    visitUsedValuesDefinedAbove(region, region, callback);
}

void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
                                     llvm::SetVector<Value> &values) {
  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
    values.insert(operand->get());
  });
}

void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                     llvm::SetVector<Value> &values) {
  for (Region &region : regions)
    getUsedValuesDefinedAbove(region, region, values);
}

//===----------------------------------------------------------------------===//
// Unreachable Block Elimination
//===----------------------------------------------------------------------===//

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // If any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;
    // If this is a single-block region, just collect the nested regions.
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Traversal populates `reachable` as a side effect */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        block.dropAllDefinedValueUses();
        block.erase();
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}

//===----------------------------------------------------------------------===//
// Dead Code Elimination
//===----------------------------------------------------------------------===//

namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because an Operation can have multiple results, this data structure tracks
/// liveness for both Values and Operations to avoid having to look through
/// all of an Operation's results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) { return liveValues.count(value); }
  void setProvedLive(Value value) {
    changed |= liveValues.insert(value).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace

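/// Returns true if the given use can be proven dead even when its owning
/// operation is live. This is only possible for successor operands of
/// terminators implementing BranchOpInterface, where liveness is deduced from
/// the corresponding block argument rather than from the owning op.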
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a phi
  // node, rather than to the terminator op itself, a terminator op can't e.g.
  // "print" the value of a successor operand.
  if (owner->isKnownTerminator()) {
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}

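/// Mark `value` as proved live if any of its uses, other than uses specially
/// known to be dead, belong to an operation that was already proved live.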
static void processValue(Value value, LiveMap &liveMap) {
  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
    if (isUseSpeciallyKnownDead(use, liveMap))
      return false;
    return liveMap.wasProvenLive(use.getOwner());
  });
  if (provedLive)
    liveMap.setProvedLive(value);
}

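/// Returns true if `op` must be preserved regardless of whether its results
/// are used: ops not known to be non-terminators, ops not known to be free of
/// memory effects, and ops containing regions are all treated as live.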
static bool isOpIntrinsicallyLive(Operation *op) {
  // This pass doesn't modify the CFG, so terminators are never deleted.
  if (!op->isKnownNonTerminator())
    return true;
  // If the op has a side effect, we treat it as live.
  // TODO: Properly handle region side effects.
  return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

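/// Propagate liveness through the given terminator. The terminator itself is
/// always live; its successors' block arguments are conservatively marked live
/// unless the terminator implements BranchOpInterface and exposes the operands
/// it forwards to a given successor.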
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate
  // them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operands to a successor, conservatively mark
  // all arguments as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    if (!branchInterface.getSuccessorOperands(i))
      for (BlockArgument arg : op->getSuccessor(i)->getArguments())
        liveMap.setProvedLive(arg);
  }
}

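/// Propagate liveness information through `op`: recurse into any nested
/// regions, handle terminators specially, and otherwise mark the op live if it
/// is intrinsically live or any of its results were proved live.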
static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Every Value is either a block argument or an op result; processValue
  // computes liveness for both kinds (block arguments are handled by the
  // Region overload below).

  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->isKnownTerminator())
    return propagateTerminatorLiveness(op, liveMap);

  // Process the op itself.
  if (isOpIntrinsicallyLive(op)) {
    liveMap.setProvedLive(op);
    return;
  }
  for (Value value : op->getResults())
    processValue(value, liveMap);
  bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
    return liveMap.wasProvenLive(value);
  });
  if (provedLive)
    liveMap.setProvedLive(op);
}

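/// Propagate liveness through all blocks of `region`, visiting blocks in
/// post-order and ops in reverse within each block so that uses tend to be
/// visited before defs.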
static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);
    for (Value value : block->getArguments())
      processValue(value, liveMap);
  }
}

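/// Erase the successor operands of `terminator` that forward values to block
/// arguments which were not proved live. This is a no-op for terminators that
/// do not implement BranchOpInterface.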
static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient,
    // since it erases later operands of the terminator first, which reduces
    // the amount of operand shifting (and thus the quadratic cost).
    unsigned succ = succE - succI - 1;
    Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
    if (!succOperands)
      continue;
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        branchOp.eraseSuccessorOperand(succ, arg);
    }
  }
}

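/// Erase everything within `regions` that was not proved live: operations,
/// non-entry block arguments, and the successor operands feeding dead
/// arguments. Returns success if anything was erased, failure otherwise.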
static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;

    // We do the deletion in an order that deletes all uses before deleting
    // defs.
    // MLIR's SSA structural invariants guarantee that except for block
    // arguments, the use-def graph is acyclic, so this is possible with a
    // single walk of ops and then a final pass to clean up block arguments.
    //
    // To do this, we visit ops in an order that visits domtree children
    // before domtree parents. A CFG post-order (with reverse iteration within
    // a block) satisfies that without needing an explicit domtree calculation.
    for (Block *block : llvm::post_order(&region.front())) {
      eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        erasedAnything |=
            succeeded(deleteDeadness(childOp.getRegions(), liveMap));
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.erase();
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with its enclosing op, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      // Iterate in reverse to avoid shifting later arguments when deleting
      // earlier arguments.
      for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
        if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
          block.eraseArgument(e - i - 1);
          erasedAnything = true;
        }
    }
  }
  return success(erasedAnything);
}

// This function performs a simple dead code elimination algorithm over the
// given regions.
//
// The overall goal is to prove that Values are dead, which allows deleting ops
// and block arguments.
//
// This uses an optimistic algorithm that assumes everything is dead until
// proved otherwise, allowing it to delete recursively dead cycles.
//
// This is a simple fixed-point dataflow analysis algorithm on a lattice
// {Dead,Alive}. Because liveness flows backward, we generally try to
// iterate everything backward to speed up convergence to the fixed-point. This
// also lets us delete recursively dead cycles of the use-def graph, including
// block arguments.
//
// This function returns success if any operations or arguments were deleted,
// failure otherwise.
static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
  LiveMap liveMap;
  do {
    liveMap.resetChanged();

    for (Region &region : regions)
      propagateLiveness(region, liveMap);
  } while (liveMap.hasChanged());

  return deleteDeadness(regions, liveMap);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
  LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
  LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
  return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
}