1 //===- CheckUses.cpp - Expensive transform value validity checks ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a pass that performs expensive opt-in checks for Transform
10 // dialect values being potentially used after they have been consumed.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15 #include "mlir/Dialect/Transform/Transforms/Passes.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Pass/Pass.h"
18 #include "llvm/ADT/SetOperations.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 /// Returns a reference to a cached set of blocks that are reachable from the
25 /// given block via edges computed by the `getNextNodes` function. For example,
26 /// if `getNextNodes` returns successors of a block, this will return the set of
27 /// reachable blocks; if it returns predecessors of a block, this will return
28 /// the set of blocks from which the given block can be reached. The block is
29 /// considered reachable form itself only if there is a cycle.
30 template <typename FnTy>
31 const llvm::SmallPtrSet<Block *, 4> &
getReachableImpl(Block * block,FnTy getNextNodes,DenseMap<Block *,llvm::SmallPtrSet<Block *,4>> & cache)32 getReachableImpl(Block *block, FnTy getNextNodes,
33                  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
34   auto it = cache.find(block);
35   if (it != cache.end())
36     return it->getSecond();
37 
38   llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
39   SmallVector<Block *> worklist;
40   worklist.push_back(block);
41   while (!worklist.empty()) {
42     Block *current = worklist.pop_back_val();
43     for (Block *predecessor : getNextNodes(current)) {
44       // The block is reachable from its transitive predecessors. Only add
45       // them to the worklist if they weren't already visited.
46       if (reachable.insert(predecessor).second)
47         worklist.push_back(predecessor);
48     }
49   }
50   return reachable;
51 }
52 
53 /// An analysis that identifies whether a value allocated by a Transform op may
54 /// be used by another such op after it may have been freed by a third op on
55 /// some control flow path. This is conceptually similar to a data flow
56 /// analysis, but relies on side effects related to particular values that
57 /// currently cannot be modeled by the MLIR data flow analysis framework (also,
58 /// the lattice element would be rather expensive as it would need to include
59 /// live and/or freed values for each operation).
60 ///
61 /// This analysis is conservatively pessimisic: it will consider that a value
62 /// may be freed if it is freed on any possible control flow path between its
63 /// allocation and a relevant use, even if the control never actually flows
64 /// through the operation that frees the value. It also does not differentiate
65 /// between may- (freed on at least one control flow path) and must-free (freed
66 /// on all possible control flow paths) because it would require expensive graph
67 /// algorithms.
68 ///
69 /// It is intended as an additional non-blocking verification or debugging aid
70 /// for ops in the Transform dialect. It leverages the requirement for Transform
71 /// dialect ops to implement the MemoryEffectsOpInterface, and expects the
72 /// values in the Transform IR to have an allocation effect on the
73 /// TransformMappingResource when defined.
74 class TransformOpMemFreeAnalysis {
75 public:
76   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
77 
78   /// Computes the analysis for Transform ops nested in the given operation.
TransformOpMemFreeAnalysis(Operation * root)79   explicit TransformOpMemFreeAnalysis(Operation *root) {
80     root->walk([&](Operation *op) {
81       if (isa<transform::TransformOpInterface>(op)) {
82         collectFreedValues(op);
83         return WalkResult::skip();
84       }
85       return WalkResult::advance();
86     });
87   }
88 
89   /// A list of operations that may be deleting a value. Non-empty list
90   /// contextually converts to boolean "true" value.
91   class PotentialDeleters {
92   public:
93     /// Creates an empty list that corresponds to the value being live.
live()94     static PotentialDeleters live() { return PotentialDeleters({}); }
95 
96     /// Creates a list from the operations that may be deleting the value.
maybeFreed(ArrayRef<Operation * > deleters)97     static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
98       return PotentialDeleters(deleters);
99     }
100 
101     /// Converts to "true" if there are operations that may be deleting the
102     /// value.
operator bool() const103     explicit operator bool() const { return !deleters.empty(); }
104 
105     /// Concatenates the lists of operations that may be deleting the value. The
106     /// value is known to be live if the reuslting list is still empty.
operator |=(const PotentialDeleters & other)107     PotentialDeleters &operator|=(const PotentialDeleters &other) {
108       llvm::append_range(deleters, other.deleters);
109       return *this;
110     }
111 
112     /// Returns the list of ops that may be deleting the value.
getOps() const113     ArrayRef<Operation *> getOps() const { return deleters; }
114 
115   private:
116     /// Constructs the list from the given operations.
PotentialDeleters(ArrayRef<Operation * > ops)117     explicit PotentialDeleters(ArrayRef<Operation *> ops) {
118       llvm::append_range(deleters, ops);
119     }
120 
121     /// The list of operations that may be deleting the value.
122     SmallVector<Operation *> deleters;
123   };
124 
125   /// Returns the list of operations that may be deleting the operand value on
126   /// any control flow path between the definition of the value and its use as
127   /// the given operand. For the purposes of this analysis, the value is
128   /// considered to be allocated at its definition point and never re-allocated.
isUseLive(OpOperand & operand)129   PotentialDeleters isUseLive(OpOperand &operand) {
130     const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
131     if (deleters.empty())
132       return live();
133 
134 #ifndef NDEBUG
135     // Check that the definition point actually allcoates the value.
136     Operation *valueSource =
137         operand.get().isa<OpResult>()
138             ? operand.get().getDefiningOp()
139             : operand.get().getParentBlock()->getParentOp();
140     auto iface = cast<MemoryEffectOpInterface>(valueSource);
141     SmallVector<MemoryEffects::EffectInstance> instances;
142     iface.getEffectsOnResource(transform::TransformMappingResource::get(),
143                                instances);
144     assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
145            "expected the op defining the value to have an allocation effect "
146            "on it");
147 #endif
148 
149     // Collect ancestors of the use operation.
150     Block *defBlock = operand.get().getParentBlock();
151     SmallVector<Operation *> ancestors;
152     Operation *ancestor = operand.getOwner();
153     do {
154       ancestors.push_back(ancestor);
155       if (ancestor->getParentRegion() == defBlock->getParent())
156         break;
157       ancestor = ancestor->getParentOp();
158     } while (true);
159     std::reverse(ancestors.begin(), ancestors.end());
160 
161     // Consider the control flow from the definition point of the value to its
162     // use point. If the use is located in some nested region, consider the path
163     // from the entry block of the region to the use.
164     for (Operation *ancestor : ancestors) {
165       // The block should be considered partially if it is the block that
166       // contains the definition (allocation) of the value being used, and the
167       // value is defined in the middle of the block, i.e., is not a block
168       // argument.
169       bool isOutermost = ancestor == ancestors.front();
170       bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
171 
172       // Check if the value may be freed by operations between its definition
173       // (allocation) point in its block and the terminator of the block or the
174       // ancestor of the use if it is located in the same block. This is only
175       // done for partial blocks here, full blocks will be considered below
176       // similarly to other blocks.
177       if (isFromBlockPartial) {
178         bool defUseSameBlock = ancestor->getBlock() == defBlock;
179         // Consider all ops from the def to its block terminator, except the
180         // when the use is in the same block, in which case only consider the
181         // ops until the user.
182         if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
183                 operand.get().getDefiningOp(), operand.get(),
184                 defUseSameBlock ? ancestor : nullptr))
185           return potentialDeleters;
186       }
187 
188       // Check if the value may be freed by opeations preceding the ancestor in
189       // its block. Skip the check for partial blocks that contain both the
190       // definition and the use point, as this has been already checked above.
191       if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
192         if (PotentialDeleters potentialDeleters =
193                 isFreedInBlockBefore(ancestor, operand.get()))
194           return potentialDeleters;
195       }
196 
197       // Check if the value may be freed by operations in any of the blocks
198       // between the definition point (in the outermost region) or the entry
199       // block of the region (in other regions) and the operand or its ancestor
200       // in the region. This includes the entire "form" block if (1) the block
201       // has not been considered as partial above and (2) the block can be
202       // reached again through some control-flow loop. This includes the entire
203       // "to" block if it can be reached form itself through some control-flow
204       // cycle, regardless of whether it has been visited before.
205       Block *ancestorBlock = ancestor->getBlock();
206       Block *from =
207           isOutermost ? defBlock : &ancestorBlock->getParent()->front();
208       if (PotentialDeleters potentialDeleters =
209               isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
210                                   /*alwaysIncludeFrom=*/!isFromBlockPartial))
211         return potentialDeleters;
212     }
213     return live();
214   }
215 
216 private:
217   /// Make PotentialDeleters constructors available with shorter names.
maybeFreed(ArrayRef<Operation * > deleters)218   static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
219     return PotentialDeleters::maybeFreed(deleters);
220   }
live()221   static PotentialDeleters live() { return PotentialDeleters::live(); }
222 
223   /// Returns the list of operations that may be deleting the given value betwen
224   /// the first and last operations, non-inclusive. `getNext` indicates the
225   /// direction of the traversal.
226   PotentialDeleters
isFreedBetween(Value value,Operation * first,Operation * last,llvm::function_ref<Operation * (Operation *)> getNext) const227   isFreedBetween(Value value, Operation *first, Operation *last,
228                  llvm::function_ref<Operation *(Operation *)> getNext) const {
229     auto it = freedBy.find(value);
230     if (it == freedBy.end())
231       return live();
232     const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
233     for (Operation *op = getNext(first); op != last; op = getNext(op)) {
234       if (deleters.contains(op))
235         return maybeFreed(op);
236     }
237     return live();
238   }
239 
240   /// Returns the list of operations that may be deleting the given value
241   /// between `root` and `before` values. `root` is expected to be in the same
242   /// block as `before` and precede it. If `before` is null, consider all
243   /// operations until the end of the block including the terminator.
isFreedInBlockAfter(Operation * root,Value value,Operation * before=nullptr) const244   PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
245                                         Operation *before = nullptr) const {
246     return isFreedBetween(value, root, before,
247                           [](Operation *op) { return op->getNextNode(); });
248   }
249 
250   /// Returns the list of operations that may be deleting the given value
251   /// between the entry of the block and the `root` operation.
isFreedInBlockBefore(Operation * root,Value value) const252   PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
253     return isFreedBetween(value, root, nullptr,
254                           [](Operation *op) { return op->getPrevNode(); });
255   }
256 
257   /// Returns the list of operations that may be deleting the given value on
258   /// any of the control flow paths between the "form" and the "to" block. The
259   /// operations from any block visited on any control flow path are
260   /// consdiered. The "from" block is considered if there is a control flow
261   /// cycle going through it, i.e., if there is a possibility that all
262   /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
263   /// set. The "to" block is considered only if there is a control flow cycle
264   /// going through it.
isMaybeFreedOnPaths(Block * from,Block * to,Value value,bool alwaysIncludeFrom)265   PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
266                                         bool alwaysIncludeFrom) {
267     // Find all blocks that lie on any path between "from" and "to", i.e., the
268     // intersection of blocks reachable from "from" and blocks from which "to"
269     // is rechable.
270     const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
271     if (!sources.contains(from))
272       return live();
273 
274     llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
275     llvm::set_intersect(reachable, sources);
276 
277     // If requested, include the "from" block that may not be present in the set
278     // of visited blocks when there is no cycle going through it.
279     if (alwaysIncludeFrom)
280       reachable.insert(from);
281 
282     // Join potential deleters from all blocks as we don't know here which of
283     // the paths through the control flow is taken.
284     PotentialDeleters potentialDeleters = live();
285     for (Block *block : reachable) {
286       for (Operation &op : *block) {
287         if (freedBy[value].count(&op))
288           potentialDeleters |= maybeFreed(&op);
289       }
290     }
291     return potentialDeleters;
292   }
293 
294   /// Popualtes `reachable` with the set of blocks that are rechable from the
295   /// given block. A block is considered reachable from itself if there is a
296   /// cycle in the control-flow graph that invovles the block.
getReachable(Block * block)297   const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
298     return getReachableImpl(
299         block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
300   }
301 
302   /// Populates `sources` with the set of blocks from which the given block is
303   /// reachable.
getReachableFrom(Block * block)304   const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
305     return getReachableImpl(
306         block, [](Block *b) { return b->getPredecessors(); },
307         reachableFromCache);
308   }
309 
310   /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
311   template <typename EffectTy>
hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,Value value)312   static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
313                         Value value) {
314     return llvm::any_of(instances,
315                         [&](const MemoryEffects::EffectInstance &instance) {
316                           return instance.getValue() == value &&
317                                  isa<EffectTy>(instance.getEffect());
318                         });
319   }
320 
321   /// Records the values that are being freed by an operation or any of its
322   /// children in `freedBy`.
collectFreedValues(Operation * root)323   void collectFreedValues(Operation *root) {
324     SmallVector<MemoryEffects::EffectInstance> instances;
325     root->walk([&](Operation *child) {
326       // TODO: extend this to conservatively handle operations with undeclared
327       // side effects as maybe freeing the operands.
328       auto iface = cast<MemoryEffectOpInterface>(child);
329       instances.clear();
330       iface.getEffectsOnResource(transform::TransformMappingResource::get(),
331                                  instances);
332       for (Value operand : child->getOperands()) {
333         if (hasEffect<MemoryEffects::Free>(instances, operand)) {
334           // All parents of the operation that frees a value should be
335           // considered as potentially freeing the value as well.
336           //
337           // TODO: differentiate between must-free/may-free as well as between
338           // this op having the effect and children having the effect. This may
339           // require some analysis of all control flow paths through the nested
340           // regions as well as a mechanism to separate proper side effects from
341           // those obtained by nesting.
342           Operation *parent = child;
343           do {
344             freedBy[operand].insert(parent);
345             if (parent == root)
346               break;
347             parent = parent->getParentOp();
348           } while (true);
349         }
350       }
351     });
352   }
353 
354   /// The mapping from a value to operations that have a Free memory effect on
355   /// the TransformMappingResource and associated with this value, or to
356   /// Transform operations transitively containing such operations.
357   DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
358 
359   /// Caches for sets of reachable blocks.
360   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
361   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
362 };
363 
364 #define GEN_PASS_CLASSES
365 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
366 
367 //// A simple pass that warns about any use of a value by a transform operation
368 // that may be using the value after it has been freed.
369 class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
370 public:
runOnOperation()371   void runOnOperation() override {
372     auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
373 
374     getOperation()->walk([&](Operation *child) {
375       for (OpOperand &operand : child->getOpOperands()) {
376         TransformOpMemFreeAnalysis::PotentialDeleters deleters =
377             analysis.isUseLive(operand);
378         if (!deleters)
379           continue;
380 
381         InFlightDiagnostic diag = child->emitWarning()
382                                   << "operand #" << operand.getOperandNumber()
383                                   << " may be used after free";
384         diag.attachNote(operand.get().getLoc()) << "allocated here";
385         for (Operation *d : deleters.getOps()) {
386           diag.attachNote(d->getLoc()) << "freed here";
387         }
388       }
389     });
390   }
391 };
392 
393 } // namespace
394 
395 namespace mlir {
396 namespace transform {
createCheckUsesPass()397 std::unique_ptr<Pass> createCheckUsesPass() {
398   return std::make_unique<CheckUsesPass>();
399 }
400 } // namespace transform
401 } // namespace mlir
402