1 //===- CheckUses.cpp - Expensive transform value validity checks ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a pass that performs expensive opt-in checks for Transform
10 // dialect values being potentially used after they have been consumed.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/DataFlowAnalysis.h"
15 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
16 #include "mlir/Dialect/Transform/Transforms/Passes.h"
17 #include "mlir/Interfaces/SideEffectInterfaces.h"
18 #include "mlir/Pass/Pass.h"
19 #include "llvm/ADT/SetOperations.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Returns a reference to a cached set of blocks that are reachable from the
26 /// given block via edges computed by the `getNextNodes` function. For example,
27 /// if `getNextNodes` returns successors of a block, this will return the set of
28 /// reachable blocks; if it returns predecessors of a block, this will return
29 /// the set of blocks from which the given block can be reached. The block is
30 /// considered reachable form itself only if there is a cycle.
31 template <typename FnTy>
32 const llvm::SmallPtrSet<Block *, 4> &
33 getReachableImpl(Block *block, FnTy getNextNodes,
34                  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
35   auto it = cache.find(block);
36   if (it != cache.end())
37     return it->getSecond();
38 
39   llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
40   SmallVector<Block *> worklist;
41   worklist.push_back(block);
42   while (!worklist.empty()) {
43     Block *current = worklist.pop_back_val();
44     for (Block *predecessor : getNextNodes(current)) {
45       // The block is reachable from its transitive predecessors. Only add
46       // them to the worklist if they weren't already visited.
47       if (reachable.insert(predecessor).second)
48         worklist.push_back(predecessor);
49     }
50   }
51   return reachable;
52 }
53 
54 /// An analysis that identifies whether a value allocated by a Transform op may
55 /// be used by another such op after it may have been freed by a third op on
56 /// some control flow path. This is conceptually similar to a data flow
57 /// analysis, but relies on side effects related to particular values that
58 /// currently cannot be modeled by the MLIR data flow analysis framework (also,
59 /// the lattice element would be rather expensive as it would need to include
60 /// live and/or freed values for each operation).
61 ///
62 /// This analysis is conservatively pessimisic: it will consider that a value
63 /// may be freed if it is freed on any possible control flow path between its
64 /// allocation and a relevant use, even if the control never actually flows
65 /// through the operation that frees the value. It also does not differentiate
66 /// between may- (freed on at least one control flow path) and must-free (freed
67 /// on all possible control flow paths) because it would require expensive graph
68 /// algorithms.
69 ///
70 /// It is intended as an additional non-blocking verification or debugging aid
71 /// for ops in the Transform dialect. It leverages the requirement for Transform
72 /// dialect ops to implement the MemoryEffectsOpInterface, and expects the
73 /// values in the Transform IR to have an allocation effect on the
74 /// TransformMappingResource when defined.
75 class TransformOpMemFreeAnalysis {
76 public:
77   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
78 
79   /// Computes the analysis for Transform ops nested in the given operation.
80   explicit TransformOpMemFreeAnalysis(Operation *root) {
81     root->walk([&](Operation *op) {
82       if (isa<transform::TransformOpInterface>(op)) {
83         collectFreedValues(op);
84         return WalkResult::skip();
85       }
86       return WalkResult::advance();
87     });
88   }
89 
90   /// A list of operations that may be deleting a value. Non-empty list
91   /// contextually converts to boolean "true" value.
92   class PotentialDeleters {
93   public:
94     /// Creates an empty list that corresponds to the value being live.
95     static PotentialDeleters live() { return PotentialDeleters({}); }
96 
97     /// Creates a list from the operations that may be deleting the value.
98     static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
99       return PotentialDeleters(deleters);
100     }
101 
102     /// Converts to "true" if there are operations that may be deleting the
103     /// value.
104     explicit operator bool() const { return !deleters.empty(); }
105 
106     /// Concatenates the lists of operations that may be deleting the value. The
107     /// value is known to be live if the reuslting list is still empty.
108     PotentialDeleters &operator|=(const PotentialDeleters &other) {
109       llvm::append_range(deleters, other.deleters);
110       return *this;
111     }
112 
113     /// Returns the list of ops that may be deleting the value.
114     ArrayRef<Operation *> getOps() const { return deleters; }
115 
116   private:
117     /// Constructs the list from the given operations.
118     explicit PotentialDeleters(ArrayRef<Operation *> ops) {
119       llvm::append_range(deleters, ops);
120     }
121 
122     /// The list of operations that may be deleting the value.
123     SmallVector<Operation *> deleters;
124   };
125 
126   /// Returns the list of operations that may be deleting the operand value on
127   /// any control flow path between the definition of the value and its use as
128   /// the given operand. For the purposes of this analysis, the value is
129   /// considered to be allocated at its definition point and never re-allocated.
130   PotentialDeleters isUseLive(OpOperand &operand) {
131     const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
132     if (deleters.empty())
133       return live();
134 
135 #ifndef NDEBUG
136     // Check that the definition point actually allcoates the value.
137     Operation *valueSource =
138         operand.get().isa<OpResult>()
139             ? operand.get().getDefiningOp()
140             : operand.get().getParentBlock()->getParentOp();
141     auto iface = cast<MemoryEffectOpInterface>(valueSource);
142     SmallVector<MemoryEffects::EffectInstance> instances;
143     iface.getEffectsOnResource(transform::TransformMappingResource::get(),
144                                instances);
145     assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
146            "expected the op defining the value to have an allocation effect "
147            "on it");
148 #endif
149 
150     // Collect ancestors of the use operation.
151     Block *defBlock = operand.get().getParentBlock();
152     SmallVector<Operation *> ancestors;
153     Operation *ancestor = operand.getOwner();
154     do {
155       ancestors.push_back(ancestor);
156       if (ancestor->getParentRegion() == defBlock->getParent())
157         break;
158       ancestor = ancestor->getParentOp();
159     } while (true);
160     std::reverse(ancestors.begin(), ancestors.end());
161 
162     // Consider the control flow from the definition point of the value to its
163     // use point. If the use is located in some nested region, consider the path
164     // from the entry block of the region to the use.
165     for (Operation *ancestor : ancestors) {
166       // The block should be considered partially if it is the block that
167       // contains the definition (allocation) of the value being used, and the
168       // value is defined in the middle of the block, i.e., is not a block
169       // argument.
170       bool isOutermost = ancestor == ancestors.front();
171       bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
172 
173       // Check if the value may be freed by operations between its definition
174       // (allocation) point in its block and the terminator of the block or the
175       // ancestor of the use if it is located in the same block. This is only
176       // done for partial blocks here, full blocks will be considered below
177       // similarly to other blocks.
178       if (isFromBlockPartial) {
179         bool defUseSameBlock = ancestor->getBlock() == defBlock;
180         // Consider all ops from the def to its block terminator, except the
181         // when the use is in the same block, in which case only consider the
182         // ops until the user.
183         if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
184                 operand.get().getDefiningOp(), operand.get(),
185                 defUseSameBlock ? ancestor : nullptr))
186           return potentialDeleters;
187       }
188 
189       // Check if the value may be freed by opeations preceding the ancestor in
190       // its block. Skip the check for partial blocks that contain both the
191       // definition and the use point, as this has been already checked above.
192       if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
193         if (PotentialDeleters potentialDeleters =
194                 isFreedInBlockBefore(ancestor, operand.get()))
195           return potentialDeleters;
196       }
197 
198       // Check if the value may be freed by operations in any of the blocks
199       // between the definition point (in the outermost region) or the entry
200       // block of the region (in other regions) and the operand or its ancestor
201       // in the region. This includes the entire "form" block if (1) the block
202       // has not been considered as partial above and (2) the block can be
203       // reached again through some control-flow loop. This includes the entire
204       // "to" block if it can be reached form itself through some control-flow
205       // cycle, regardless of whether it has been visited before.
206       Block *ancestorBlock = ancestor->getBlock();
207       Block *from =
208           isOutermost ? defBlock : &ancestorBlock->getParent()->front();
209       if (PotentialDeleters potentialDeleters =
210               isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
211                                   /*alwaysIncludeFrom=*/!isFromBlockPartial))
212         return potentialDeleters;
213     }
214     return live();
215   }
216 
217 private:
218   /// Make PotentialDeleters constructors available with shorter names.
219   static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
220     return PotentialDeleters::maybeFreed(deleters);
221   }
222   static PotentialDeleters live() { return PotentialDeleters::live(); }
223 
224   /// Returns the list of operations that may be deleting the given value betwen
225   /// the first and last operations, non-inclusive. `getNext` indicates the
226   /// direction of the traversal.
227   PotentialDeleters
228   isFreedBetween(Value value, Operation *first, Operation *last,
229                  llvm::function_ref<Operation *(Operation *)> getNext) const {
230     auto it = freedBy.find(value);
231     if (it == freedBy.end())
232       return live();
233     const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
234     for (Operation *op = getNext(first); op != last; op = getNext(op)) {
235       if (deleters.contains(op))
236         return maybeFreed(op);
237     }
238     return live();
239   }
240 
241   /// Returns the list of operations that may be deleting the given value
242   /// between `root` and `before` values. `root` is expected to be in the same
243   /// block as `before` and precede it. If `before` is null, consider all
244   /// operations until the end of the block including the terminator.
245   PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
246                                         Operation *before = nullptr) const {
247     return isFreedBetween(value, root, before,
248                           [](Operation *op) { return op->getNextNode(); });
249   }
250 
251   /// Returns the list of operations that may be deleting the given value
252   /// between the entry of the block and the `root` operation.
253   PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
254     return isFreedBetween(value, root, nullptr,
255                           [](Operation *op) { return op->getPrevNode(); });
256   }
257 
258   /// Returns the list of operations that may be deleting the given value on
259   /// any of the control flow paths between the "form" and the "to" block. The
260   /// operations from any block visited on any control flow path are
261   /// consdiered. The "from" block is considered if there is a control flow
262   /// cycle going through it, i.e., if there is a possibility that all
263   /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
264   /// set. The "to" block is considered only if there is a control flow cycle
265   /// going through it.
266   PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
267                                         bool alwaysIncludeFrom) {
268     // Find all blocks that lie on any path between "from" and "to", i.e., the
269     // intersection of blocks reachable from "from" and blocks from which "to"
270     // is rechable.
271     const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
272     if (!sources.contains(from))
273       return live();
274 
275     llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
276     llvm::set_intersect(reachable, sources);
277 
278     // If requested, include the "from" block that may not be present in the set
279     // of visited blocks when there is no cycle going through it.
280     if (alwaysIncludeFrom)
281       reachable.insert(from);
282 
283     // Join potential deleters from all blocks as we don't know here which of
284     // the paths through the control flow is taken.
285     PotentialDeleters potentialDeleters = live();
286     for (Block *block : reachable) {
287       for (Operation &op : *block) {
288         if (freedBy[value].count(&op))
289           potentialDeleters |= maybeFreed(&op);
290       }
291     }
292     return potentialDeleters;
293   }
294 
295   /// Popualtes `reachable` with the set of blocks that are rechable from the
296   /// given block. A block is considered reachable from itself if there is a
297   /// cycle in the control-flow graph that invovles the block.
298   const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
299     return getReachableImpl(
300         block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
301   }
302 
303   /// Populates `sources` with the set of blocks from which the given block is
304   /// reachable.
305   const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
306     return getReachableImpl(
307         block, [](Block *b) { return b->getPredecessors(); },
308         reachableFromCache);
309   }
310 
311   /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
312   template <typename EffectTy>
313   static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
314                         Value value) {
315     return llvm::any_of(instances,
316                         [&](const MemoryEffects::EffectInstance &instance) {
317                           return instance.getValue() == value &&
318                                  isa<EffectTy>(instance.getEffect());
319                         });
320   }
321 
322   /// Records the values that are being freed by an operation or any of its
323   /// children in `freedBy`.
324   void collectFreedValues(Operation *root) {
325     SmallVector<MemoryEffects::EffectInstance> instances;
326     root->walk([&](Operation *child) {
327       // TODO: extend this to conservatively handle operations with undeclared
328       // side effects as maybe freeing the operands.
329       auto iface = cast<MemoryEffectOpInterface>(child);
330       instances.clear();
331       iface.getEffectsOnResource(transform::TransformMappingResource::get(),
332                                  instances);
333       for (Value operand : child->getOperands()) {
334         if (hasEffect<MemoryEffects::Free>(instances, operand)) {
335           // All parents of the operation that frees a value should be
336           // considered as potentially freeing the value as well.
337           //
338           // TODO: differentiate between must-free/may-free as well as between
339           // this op having the effect and children having the effect. This may
340           // require some analysis of all control flow paths through the nested
341           // regions as well as a mechanism to separate proper side effects from
342           // those obtained by nesting.
343           Operation *parent = child;
344           do {
345             freedBy[operand].insert(parent);
346             if (parent == root)
347               break;
348             parent = parent->getParentOp();
349           } while (true);
350         }
351       }
352     });
353   }
354 
355   /// The mapping from a value to operations that have a Free memory effect on
356   /// the TransformMappingResource and associated with this value, or to
357   /// Transform operations transitively containing such operations.
358   DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
359 
360   /// Caches for sets of reachable blocks.
361   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
362   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
363 };
364 
365 #define GEN_PASS_CLASSES
366 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
367 
368 //// A simple pass that warns about any use of a value by a transform operation
369 // that may be using the value after it has been freed.
370 class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
371 public:
372   void runOnOperation() override {
373     auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
374 
375     getOperation()->walk([&](Operation *child) {
376       for (OpOperand &operand : child->getOpOperands()) {
377         TransformOpMemFreeAnalysis::PotentialDeleters deleters =
378             analysis.isUseLive(operand);
379         if (!deleters)
380           continue;
381 
382         InFlightDiagnostic diag = child->emitWarning()
383                                   << "operand #" << operand.getOperandNumber()
384                                   << " may be used after free";
385         diag.attachNote(operand.get().getLoc()) << "allocated here";
386         for (Operation *d : deleters.getOps()) {
387           diag.attachNote(d->getLoc()) << "freed here";
388         }
389       }
390     });
391   }
392 };
393 
394 } // namespace
395 
396 namespace mlir {
397 namespace transform {
398 std::unique_ptr<Pass> createCheckUsesPass() {
399   return std::make_unique<CheckUsesPass>();
400 }
401 } // namespace transform
402 } // namespace mlir
403