//===- CheckUses.cpp - Expensive transform value validity checks ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a pass that performs expensive opt-in checks for Transform
// dialect values being potentially used after they have been consumed.
//
//===----------------------------------------------------------------------===//
13*73c3dff1SAlex Zinenko 
14*73c3dff1SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15*73c3dff1SAlex Zinenko #include "mlir/Dialect/Transform/Transforms/Passes.h"
16*73c3dff1SAlex Zinenko #include "mlir/Interfaces/SideEffectInterfaces.h"
17*73c3dff1SAlex Zinenko #include "mlir/Pass/Pass.h"
18*73c3dff1SAlex Zinenko #include "llvm/ADT/SetOperations.h"
19*73c3dff1SAlex Zinenko 
20*73c3dff1SAlex Zinenko using namespace mlir;
21*73c3dff1SAlex Zinenko 
22*73c3dff1SAlex Zinenko namespace {
23*73c3dff1SAlex Zinenko 
24*73c3dff1SAlex Zinenko /// Returns a reference to a cached set of blocks that are reachable from the
25*73c3dff1SAlex Zinenko /// given block via edges computed by the `getNextNodes` function. For example,
26*73c3dff1SAlex Zinenko /// if `getNextNodes` returns successors of a block, this will return the set of
27*73c3dff1SAlex Zinenko /// reachable blocks; if it returns predecessors of a block, this will return
28*73c3dff1SAlex Zinenko /// the set of blocks from which the given block can be reached. The block is
29*73c3dff1SAlex Zinenko /// considered reachable form itself only if there is a cycle.
30*73c3dff1SAlex Zinenko template <typename FnTy>
31*73c3dff1SAlex Zinenko const llvm::SmallPtrSet<Block *, 4> &
getReachableImpl(Block * block,FnTy getNextNodes,DenseMap<Block *,llvm::SmallPtrSet<Block *,4>> & cache)32*73c3dff1SAlex Zinenko getReachableImpl(Block *block, FnTy getNextNodes,
33*73c3dff1SAlex Zinenko                  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
34*73c3dff1SAlex Zinenko   auto it = cache.find(block);
35*73c3dff1SAlex Zinenko   if (it != cache.end())
36*73c3dff1SAlex Zinenko     return it->getSecond();
37*73c3dff1SAlex Zinenko 
38*73c3dff1SAlex Zinenko   llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
39*73c3dff1SAlex Zinenko   SmallVector<Block *> worklist;
40*73c3dff1SAlex Zinenko   worklist.push_back(block);
41*73c3dff1SAlex Zinenko   while (!worklist.empty()) {
42*73c3dff1SAlex Zinenko     Block *current = worklist.pop_back_val();
43*73c3dff1SAlex Zinenko     for (Block *predecessor : getNextNodes(current)) {
44*73c3dff1SAlex Zinenko       // The block is reachable from its transitive predecessors. Only add
45*73c3dff1SAlex Zinenko       // them to the worklist if they weren't already visited.
46*73c3dff1SAlex Zinenko       if (reachable.insert(predecessor).second)
47*73c3dff1SAlex Zinenko         worklist.push_back(predecessor);
48*73c3dff1SAlex Zinenko     }
49*73c3dff1SAlex Zinenko   }
50*73c3dff1SAlex Zinenko   return reachable;
51*73c3dff1SAlex Zinenko }
52*73c3dff1SAlex Zinenko 
53*73c3dff1SAlex Zinenko /// An analysis that identifies whether a value allocated by a Transform op may
54*73c3dff1SAlex Zinenko /// be used by another such op after it may have been freed by a third op on
55*73c3dff1SAlex Zinenko /// some control flow path. This is conceptually similar to a data flow
56*73c3dff1SAlex Zinenko /// analysis, but relies on side effects related to particular values that
57*73c3dff1SAlex Zinenko /// currently cannot be modeled by the MLIR data flow analysis framework (also,
58*73c3dff1SAlex Zinenko /// the lattice element would be rather expensive as it would need to include
59*73c3dff1SAlex Zinenko /// live and/or freed values for each operation).
60*73c3dff1SAlex Zinenko ///
61*73c3dff1SAlex Zinenko /// This analysis is conservatively pessimisic: it will consider that a value
62*73c3dff1SAlex Zinenko /// may be freed if it is freed on any possible control flow path between its
63*73c3dff1SAlex Zinenko /// allocation and a relevant use, even if the control never actually flows
64*73c3dff1SAlex Zinenko /// through the operation that frees the value. It also does not differentiate
65*73c3dff1SAlex Zinenko /// between may- (freed on at least one control flow path) and must-free (freed
66*73c3dff1SAlex Zinenko /// on all possible control flow paths) because it would require expensive graph
67*73c3dff1SAlex Zinenko /// algorithms.
68*73c3dff1SAlex Zinenko ///
69*73c3dff1SAlex Zinenko /// It is intended as an additional non-blocking verification or debugging aid
70*73c3dff1SAlex Zinenko /// for ops in the Transform dialect. It leverages the requirement for Transform
71*73c3dff1SAlex Zinenko /// dialect ops to implement the MemoryEffectsOpInterface, and expects the
72*73c3dff1SAlex Zinenko /// values in the Transform IR to have an allocation effect on the
73*73c3dff1SAlex Zinenko /// TransformMappingResource when defined.
74*73c3dff1SAlex Zinenko class TransformOpMemFreeAnalysis {
75*73c3dff1SAlex Zinenko public:
76*73c3dff1SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
77*73c3dff1SAlex Zinenko 
78*73c3dff1SAlex Zinenko   /// Computes the analysis for Transform ops nested in the given operation.
TransformOpMemFreeAnalysis(Operation * root)79*73c3dff1SAlex Zinenko   explicit TransformOpMemFreeAnalysis(Operation *root) {
80*73c3dff1SAlex Zinenko     root->walk([&](Operation *op) {
81*73c3dff1SAlex Zinenko       if (isa<transform::TransformOpInterface>(op)) {
82*73c3dff1SAlex Zinenko         collectFreedValues(op);
83*73c3dff1SAlex Zinenko         return WalkResult::skip();
84*73c3dff1SAlex Zinenko       }
85*73c3dff1SAlex Zinenko       return WalkResult::advance();
86*73c3dff1SAlex Zinenko     });
87*73c3dff1SAlex Zinenko   }
88*73c3dff1SAlex Zinenko 
89*73c3dff1SAlex Zinenko   /// A list of operations that may be deleting a value. Non-empty list
90*73c3dff1SAlex Zinenko   /// contextually converts to boolean "true" value.
91*73c3dff1SAlex Zinenko   class PotentialDeleters {
92*73c3dff1SAlex Zinenko   public:
93*73c3dff1SAlex Zinenko     /// Creates an empty list that corresponds to the value being live.
live()94*73c3dff1SAlex Zinenko     static PotentialDeleters live() { return PotentialDeleters({}); }
95*73c3dff1SAlex Zinenko 
96*73c3dff1SAlex Zinenko     /// Creates a list from the operations that may be deleting the value.
maybeFreed(ArrayRef<Operation * > deleters)97*73c3dff1SAlex Zinenko     static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
98*73c3dff1SAlex Zinenko       return PotentialDeleters(deleters);
99*73c3dff1SAlex Zinenko     }
100*73c3dff1SAlex Zinenko 
101*73c3dff1SAlex Zinenko     /// Converts to "true" if there are operations that may be deleting the
102*73c3dff1SAlex Zinenko     /// value.
operator bool() const103*73c3dff1SAlex Zinenko     explicit operator bool() const { return !deleters.empty(); }
104*73c3dff1SAlex Zinenko 
105*73c3dff1SAlex Zinenko     /// Concatenates the lists of operations that may be deleting the value. The
106*73c3dff1SAlex Zinenko     /// value is known to be live if the reuslting list is still empty.
operator |=(const PotentialDeleters & other)107*73c3dff1SAlex Zinenko     PotentialDeleters &operator|=(const PotentialDeleters &other) {
108*73c3dff1SAlex Zinenko       llvm::append_range(deleters, other.deleters);
109*73c3dff1SAlex Zinenko       return *this;
110*73c3dff1SAlex Zinenko     }
111*73c3dff1SAlex Zinenko 
112*73c3dff1SAlex Zinenko     /// Returns the list of ops that may be deleting the value.
getOps() const113*73c3dff1SAlex Zinenko     ArrayRef<Operation *> getOps() const { return deleters; }
114*73c3dff1SAlex Zinenko 
115*73c3dff1SAlex Zinenko   private:
116*73c3dff1SAlex Zinenko     /// Constructs the list from the given operations.
PotentialDeleters(ArrayRef<Operation * > ops)117*73c3dff1SAlex Zinenko     explicit PotentialDeleters(ArrayRef<Operation *> ops) {
118*73c3dff1SAlex Zinenko       llvm::append_range(deleters, ops);
119*73c3dff1SAlex Zinenko     }
120*73c3dff1SAlex Zinenko 
121*73c3dff1SAlex Zinenko     /// The list of operations that may be deleting the value.
122*73c3dff1SAlex Zinenko     SmallVector<Operation *> deleters;
123*73c3dff1SAlex Zinenko   };
124*73c3dff1SAlex Zinenko 
125*73c3dff1SAlex Zinenko   /// Returns the list of operations that may be deleting the operand value on
126*73c3dff1SAlex Zinenko   /// any control flow path between the definition of the value and its use as
127*73c3dff1SAlex Zinenko   /// the given operand. For the purposes of this analysis, the value is
128*73c3dff1SAlex Zinenko   /// considered to be allocated at its definition point and never re-allocated.
isUseLive(OpOperand & operand)129*73c3dff1SAlex Zinenko   PotentialDeleters isUseLive(OpOperand &operand) {
130*73c3dff1SAlex Zinenko     const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
131*73c3dff1SAlex Zinenko     if (deleters.empty())
132*73c3dff1SAlex Zinenko       return live();
133*73c3dff1SAlex Zinenko 
134*73c3dff1SAlex Zinenko #ifndef NDEBUG
135*73c3dff1SAlex Zinenko     // Check that the definition point actually allcoates the value.
136*73c3dff1SAlex Zinenko     Operation *valueSource =
137*73c3dff1SAlex Zinenko         operand.get().isa<OpResult>()
138*73c3dff1SAlex Zinenko             ? operand.get().getDefiningOp()
139*73c3dff1SAlex Zinenko             : operand.get().getParentBlock()->getParentOp();
140*73c3dff1SAlex Zinenko     auto iface = cast<MemoryEffectOpInterface>(valueSource);
141*73c3dff1SAlex Zinenko     SmallVector<MemoryEffects::EffectInstance> instances;
142*73c3dff1SAlex Zinenko     iface.getEffectsOnResource(transform::TransformMappingResource::get(),
143*73c3dff1SAlex Zinenko                                instances);
144*73c3dff1SAlex Zinenko     assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
145*73c3dff1SAlex Zinenko            "expected the op defining the value to have an allocation effect "
146*73c3dff1SAlex Zinenko            "on it");
147*73c3dff1SAlex Zinenko #endif
148*73c3dff1SAlex Zinenko 
149*73c3dff1SAlex Zinenko     // Collect ancestors of the use operation.
150*73c3dff1SAlex Zinenko     Block *defBlock = operand.get().getParentBlock();
151*73c3dff1SAlex Zinenko     SmallVector<Operation *> ancestors;
152*73c3dff1SAlex Zinenko     Operation *ancestor = operand.getOwner();
153*73c3dff1SAlex Zinenko     do {
154*73c3dff1SAlex Zinenko       ancestors.push_back(ancestor);
155*73c3dff1SAlex Zinenko       if (ancestor->getParentRegion() == defBlock->getParent())
156*73c3dff1SAlex Zinenko         break;
157*73c3dff1SAlex Zinenko       ancestor = ancestor->getParentOp();
158*73c3dff1SAlex Zinenko     } while (true);
159*73c3dff1SAlex Zinenko     std::reverse(ancestors.begin(), ancestors.end());
160*73c3dff1SAlex Zinenko 
161*73c3dff1SAlex Zinenko     // Consider the control flow from the definition point of the value to its
162*73c3dff1SAlex Zinenko     // use point. If the use is located in some nested region, consider the path
163*73c3dff1SAlex Zinenko     // from the entry block of the region to the use.
164*73c3dff1SAlex Zinenko     for (Operation *ancestor : ancestors) {
165*73c3dff1SAlex Zinenko       // The block should be considered partially if it is the block that
166*73c3dff1SAlex Zinenko       // contains the definition (allocation) of the value being used, and the
167*73c3dff1SAlex Zinenko       // value is defined in the middle of the block, i.e., is not a block
168*73c3dff1SAlex Zinenko       // argument.
169*73c3dff1SAlex Zinenko       bool isOutermost = ancestor == ancestors.front();
170*73c3dff1SAlex Zinenko       bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
171*73c3dff1SAlex Zinenko 
172*73c3dff1SAlex Zinenko       // Check if the value may be freed by operations between its definition
173*73c3dff1SAlex Zinenko       // (allocation) point in its block and the terminator of the block or the
174*73c3dff1SAlex Zinenko       // ancestor of the use if it is located in the same block. This is only
175*73c3dff1SAlex Zinenko       // done for partial blocks here, full blocks will be considered below
176*73c3dff1SAlex Zinenko       // similarly to other blocks.
177*73c3dff1SAlex Zinenko       if (isFromBlockPartial) {
178*73c3dff1SAlex Zinenko         bool defUseSameBlock = ancestor->getBlock() == defBlock;
179*73c3dff1SAlex Zinenko         // Consider all ops from the def to its block terminator, except the
180*73c3dff1SAlex Zinenko         // when the use is in the same block, in which case only consider the
181*73c3dff1SAlex Zinenko         // ops until the user.
182*73c3dff1SAlex Zinenko         if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
183*73c3dff1SAlex Zinenko                 operand.get().getDefiningOp(), operand.get(),
184*73c3dff1SAlex Zinenko                 defUseSameBlock ? ancestor : nullptr))
185*73c3dff1SAlex Zinenko           return potentialDeleters;
186*73c3dff1SAlex Zinenko       }
187*73c3dff1SAlex Zinenko 
188*73c3dff1SAlex Zinenko       // Check if the value may be freed by opeations preceding the ancestor in
189*73c3dff1SAlex Zinenko       // its block. Skip the check for partial blocks that contain both the
190*73c3dff1SAlex Zinenko       // definition and the use point, as this has been already checked above.
191*73c3dff1SAlex Zinenko       if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
192*73c3dff1SAlex Zinenko         if (PotentialDeleters potentialDeleters =
193*73c3dff1SAlex Zinenko                 isFreedInBlockBefore(ancestor, operand.get()))
194*73c3dff1SAlex Zinenko           return potentialDeleters;
195*73c3dff1SAlex Zinenko       }
196*73c3dff1SAlex Zinenko 
197*73c3dff1SAlex Zinenko       // Check if the value may be freed by operations in any of the blocks
198*73c3dff1SAlex Zinenko       // between the definition point (in the outermost region) or the entry
199*73c3dff1SAlex Zinenko       // block of the region (in other regions) and the operand or its ancestor
200*73c3dff1SAlex Zinenko       // in the region. This includes the entire "form" block if (1) the block
201*73c3dff1SAlex Zinenko       // has not been considered as partial above and (2) the block can be
202*73c3dff1SAlex Zinenko       // reached again through some control-flow loop. This includes the entire
203*73c3dff1SAlex Zinenko       // "to" block if it can be reached form itself through some control-flow
204*73c3dff1SAlex Zinenko       // cycle, regardless of whether it has been visited before.
205*73c3dff1SAlex Zinenko       Block *ancestorBlock = ancestor->getBlock();
206*73c3dff1SAlex Zinenko       Block *from =
207*73c3dff1SAlex Zinenko           isOutermost ? defBlock : &ancestorBlock->getParent()->front();
208*73c3dff1SAlex Zinenko       if (PotentialDeleters potentialDeleters =
209*73c3dff1SAlex Zinenko               isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
210*73c3dff1SAlex Zinenko                                   /*alwaysIncludeFrom=*/!isFromBlockPartial))
211*73c3dff1SAlex Zinenko         return potentialDeleters;
212*73c3dff1SAlex Zinenko     }
213*73c3dff1SAlex Zinenko     return live();
214*73c3dff1SAlex Zinenko   }
215*73c3dff1SAlex Zinenko 
216*73c3dff1SAlex Zinenko private:
217*73c3dff1SAlex Zinenko   /// Make PotentialDeleters constructors available with shorter names.
maybeFreed(ArrayRef<Operation * > deleters)218*73c3dff1SAlex Zinenko   static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
219*73c3dff1SAlex Zinenko     return PotentialDeleters::maybeFreed(deleters);
220*73c3dff1SAlex Zinenko   }
live()221*73c3dff1SAlex Zinenko   static PotentialDeleters live() { return PotentialDeleters::live(); }
222*73c3dff1SAlex Zinenko 
223*73c3dff1SAlex Zinenko   /// Returns the list of operations that may be deleting the given value betwen
224*73c3dff1SAlex Zinenko   /// the first and last operations, non-inclusive. `getNext` indicates the
225*73c3dff1SAlex Zinenko   /// direction of the traversal.
226*73c3dff1SAlex Zinenko   PotentialDeleters
isFreedBetween(Value value,Operation * first,Operation * last,llvm::function_ref<Operation * (Operation *)> getNext) const227*73c3dff1SAlex Zinenko   isFreedBetween(Value value, Operation *first, Operation *last,
228*73c3dff1SAlex Zinenko                  llvm::function_ref<Operation *(Operation *)> getNext) const {
229*73c3dff1SAlex Zinenko     auto it = freedBy.find(value);
230*73c3dff1SAlex Zinenko     if (it == freedBy.end())
231*73c3dff1SAlex Zinenko       return live();
232*73c3dff1SAlex Zinenko     const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
233*73c3dff1SAlex Zinenko     for (Operation *op = getNext(first); op != last; op = getNext(op)) {
234*73c3dff1SAlex Zinenko       if (deleters.contains(op))
235*73c3dff1SAlex Zinenko         return maybeFreed(op);
236*73c3dff1SAlex Zinenko     }
237*73c3dff1SAlex Zinenko     return live();
238*73c3dff1SAlex Zinenko   }
239*73c3dff1SAlex Zinenko 
240*73c3dff1SAlex Zinenko   /// Returns the list of operations that may be deleting the given value
241*73c3dff1SAlex Zinenko   /// between `root` and `before` values. `root` is expected to be in the same
242*73c3dff1SAlex Zinenko   /// block as `before` and precede it. If `before` is null, consider all
243*73c3dff1SAlex Zinenko   /// operations until the end of the block including the terminator.
isFreedInBlockAfter(Operation * root,Value value,Operation * before=nullptr) const244*73c3dff1SAlex Zinenko   PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
245*73c3dff1SAlex Zinenko                                         Operation *before = nullptr) const {
246*73c3dff1SAlex Zinenko     return isFreedBetween(value, root, before,
247*73c3dff1SAlex Zinenko                           [](Operation *op) { return op->getNextNode(); });
248*73c3dff1SAlex Zinenko   }
249*73c3dff1SAlex Zinenko 
250*73c3dff1SAlex Zinenko   /// Returns the list of operations that may be deleting the given value
251*73c3dff1SAlex Zinenko   /// between the entry of the block and the `root` operation.
isFreedInBlockBefore(Operation * root,Value value) const252*73c3dff1SAlex Zinenko   PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
253*73c3dff1SAlex Zinenko     return isFreedBetween(value, root, nullptr,
254*73c3dff1SAlex Zinenko                           [](Operation *op) { return op->getPrevNode(); });
255*73c3dff1SAlex Zinenko   }
256*73c3dff1SAlex Zinenko 
257*73c3dff1SAlex Zinenko   /// Returns the list of operations that may be deleting the given value on
258*73c3dff1SAlex Zinenko   /// any of the control flow paths between the "form" and the "to" block. The
259*73c3dff1SAlex Zinenko   /// operations from any block visited on any control flow path are
260*73c3dff1SAlex Zinenko   /// consdiered. The "from" block is considered if there is a control flow
261*73c3dff1SAlex Zinenko   /// cycle going through it, i.e., if there is a possibility that all
262*73c3dff1SAlex Zinenko   /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
263*73c3dff1SAlex Zinenko   /// set. The "to" block is considered only if there is a control flow cycle
264*73c3dff1SAlex Zinenko   /// going through it.
isMaybeFreedOnPaths(Block * from,Block * to,Value value,bool alwaysIncludeFrom)265*73c3dff1SAlex Zinenko   PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
266*73c3dff1SAlex Zinenko                                         bool alwaysIncludeFrom) {
267*73c3dff1SAlex Zinenko     // Find all blocks that lie on any path between "from" and "to", i.e., the
268*73c3dff1SAlex Zinenko     // intersection of blocks reachable from "from" and blocks from which "to"
269*73c3dff1SAlex Zinenko     // is rechable.
270*73c3dff1SAlex Zinenko     const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
271*73c3dff1SAlex Zinenko     if (!sources.contains(from))
272*73c3dff1SAlex Zinenko       return live();
273*73c3dff1SAlex Zinenko 
274*73c3dff1SAlex Zinenko     llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
275*73c3dff1SAlex Zinenko     llvm::set_intersect(reachable, sources);
276*73c3dff1SAlex Zinenko 
277*73c3dff1SAlex Zinenko     // If requested, include the "from" block that may not be present in the set
278*73c3dff1SAlex Zinenko     // of visited blocks when there is no cycle going through it.
279*73c3dff1SAlex Zinenko     if (alwaysIncludeFrom)
280*73c3dff1SAlex Zinenko       reachable.insert(from);
281*73c3dff1SAlex Zinenko 
282*73c3dff1SAlex Zinenko     // Join potential deleters from all blocks as we don't know here which of
283*73c3dff1SAlex Zinenko     // the paths through the control flow is taken.
284*73c3dff1SAlex Zinenko     PotentialDeleters potentialDeleters = live();
285*73c3dff1SAlex Zinenko     for (Block *block : reachable) {
286*73c3dff1SAlex Zinenko       for (Operation &op : *block) {
287*73c3dff1SAlex Zinenko         if (freedBy[value].count(&op))
288*73c3dff1SAlex Zinenko           potentialDeleters |= maybeFreed(&op);
289*73c3dff1SAlex Zinenko       }
290*73c3dff1SAlex Zinenko     }
291*73c3dff1SAlex Zinenko     return potentialDeleters;
292*73c3dff1SAlex Zinenko   }
293*73c3dff1SAlex Zinenko 
294*73c3dff1SAlex Zinenko   /// Popualtes `reachable` with the set of blocks that are rechable from the
295*73c3dff1SAlex Zinenko   /// given block. A block is considered reachable from itself if there is a
296*73c3dff1SAlex Zinenko   /// cycle in the control-flow graph that invovles the block.
getReachable(Block * block)297*73c3dff1SAlex Zinenko   const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
298*73c3dff1SAlex Zinenko     return getReachableImpl(
299*73c3dff1SAlex Zinenko         block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
300*73c3dff1SAlex Zinenko   }
301*73c3dff1SAlex Zinenko 
302*73c3dff1SAlex Zinenko   /// Populates `sources` with the set of blocks from which the given block is
303*73c3dff1SAlex Zinenko   /// reachable.
getReachableFrom(Block * block)304*73c3dff1SAlex Zinenko   const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
305*73c3dff1SAlex Zinenko     return getReachableImpl(
306*73c3dff1SAlex Zinenko         block, [](Block *b) { return b->getPredecessors(); },
307*73c3dff1SAlex Zinenko         reachableFromCache);
308*73c3dff1SAlex Zinenko   }
309*73c3dff1SAlex Zinenko 
310*73c3dff1SAlex Zinenko   /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
311*73c3dff1SAlex Zinenko   template <typename EffectTy>
hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,Value value)312*73c3dff1SAlex Zinenko   static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
313*73c3dff1SAlex Zinenko                         Value value) {
314*73c3dff1SAlex Zinenko     return llvm::any_of(instances,
315*73c3dff1SAlex Zinenko                         [&](const MemoryEffects::EffectInstance &instance) {
316*73c3dff1SAlex Zinenko                           return instance.getValue() == value &&
317*73c3dff1SAlex Zinenko                                  isa<EffectTy>(instance.getEffect());
318*73c3dff1SAlex Zinenko                         });
319*73c3dff1SAlex Zinenko   }
320*73c3dff1SAlex Zinenko 
321*73c3dff1SAlex Zinenko   /// Records the values that are being freed by an operation or any of its
322*73c3dff1SAlex Zinenko   /// children in `freedBy`.
collectFreedValues(Operation * root)323*73c3dff1SAlex Zinenko   void collectFreedValues(Operation *root) {
324*73c3dff1SAlex Zinenko     SmallVector<MemoryEffects::EffectInstance> instances;
325*73c3dff1SAlex Zinenko     root->walk([&](Operation *child) {
326*73c3dff1SAlex Zinenko       // TODO: extend this to conservatively handle operations with undeclared
327*73c3dff1SAlex Zinenko       // side effects as maybe freeing the operands.
328*73c3dff1SAlex Zinenko       auto iface = cast<MemoryEffectOpInterface>(child);
329*73c3dff1SAlex Zinenko       instances.clear();
330*73c3dff1SAlex Zinenko       iface.getEffectsOnResource(transform::TransformMappingResource::get(),
331*73c3dff1SAlex Zinenko                                  instances);
332*73c3dff1SAlex Zinenko       for (Value operand : child->getOperands()) {
333*73c3dff1SAlex Zinenko         if (hasEffect<MemoryEffects::Free>(instances, operand)) {
334*73c3dff1SAlex Zinenko           // All parents of the operation that frees a value should be
335*73c3dff1SAlex Zinenko           // considered as potentially freeing the value as well.
336*73c3dff1SAlex Zinenko           //
337*73c3dff1SAlex Zinenko           // TODO: differentiate between must-free/may-free as well as between
338*73c3dff1SAlex Zinenko           // this op having the effect and children having the effect. This may
339*73c3dff1SAlex Zinenko           // require some analysis of all control flow paths through the nested
340*73c3dff1SAlex Zinenko           // regions as well as a mechanism to separate proper side effects from
341*73c3dff1SAlex Zinenko           // those obtained by nesting.
342*73c3dff1SAlex Zinenko           Operation *parent = child;
343*73c3dff1SAlex Zinenko           do {
344*73c3dff1SAlex Zinenko             freedBy[operand].insert(parent);
345*73c3dff1SAlex Zinenko             if (parent == root)
346*73c3dff1SAlex Zinenko               break;
347*73c3dff1SAlex Zinenko             parent = parent->getParentOp();
348*73c3dff1SAlex Zinenko           } while (true);
349*73c3dff1SAlex Zinenko         }
350*73c3dff1SAlex Zinenko       }
351*73c3dff1SAlex Zinenko     });
352*73c3dff1SAlex Zinenko   }
353*73c3dff1SAlex Zinenko 
354*73c3dff1SAlex Zinenko   /// The mapping from a value to operations that have a Free memory effect on
355*73c3dff1SAlex Zinenko   /// the TransformMappingResource and associated with this value, or to
356*73c3dff1SAlex Zinenko   /// Transform operations transitively containing such operations.
357*73c3dff1SAlex Zinenko   DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
358*73c3dff1SAlex Zinenko 
359*73c3dff1SAlex Zinenko   /// Caches for sets of reachable blocks.
360*73c3dff1SAlex Zinenko   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
361*73c3dff1SAlex Zinenko   DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
362*73c3dff1SAlex Zinenko };
363*73c3dff1SAlex Zinenko 
364*73c3dff1SAlex Zinenko #define GEN_PASS_CLASSES
365*73c3dff1SAlex Zinenko #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
366*73c3dff1SAlex Zinenko 
367*73c3dff1SAlex Zinenko //// A simple pass that warns about any use of a value by a transform operation
368*73c3dff1SAlex Zinenko // that may be using the value after it has been freed.
369*73c3dff1SAlex Zinenko class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
370*73c3dff1SAlex Zinenko public:
runOnOperation()371*73c3dff1SAlex Zinenko   void runOnOperation() override {
372*73c3dff1SAlex Zinenko     auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
373*73c3dff1SAlex Zinenko 
374*73c3dff1SAlex Zinenko     getOperation()->walk([&](Operation *child) {
375*73c3dff1SAlex Zinenko       for (OpOperand &operand : child->getOpOperands()) {
376*73c3dff1SAlex Zinenko         TransformOpMemFreeAnalysis::PotentialDeleters deleters =
377*73c3dff1SAlex Zinenko             analysis.isUseLive(operand);
378*73c3dff1SAlex Zinenko         if (!deleters)
379*73c3dff1SAlex Zinenko           continue;
380*73c3dff1SAlex Zinenko 
381*73c3dff1SAlex Zinenko         InFlightDiagnostic diag = child->emitWarning()
382*73c3dff1SAlex Zinenko                                   << "operand #" << operand.getOperandNumber()
383*73c3dff1SAlex Zinenko                                   << " may be used after free";
384*73c3dff1SAlex Zinenko         diag.attachNote(operand.get().getLoc()) << "allocated here";
385*73c3dff1SAlex Zinenko         for (Operation *d : deleters.getOps()) {
386*73c3dff1SAlex Zinenko           diag.attachNote(d->getLoc()) << "freed here";
387*73c3dff1SAlex Zinenko         }
388*73c3dff1SAlex Zinenko       }
389*73c3dff1SAlex Zinenko     });
390*73c3dff1SAlex Zinenko   }
391*73c3dff1SAlex Zinenko };
392*73c3dff1SAlex Zinenko 
393*73c3dff1SAlex Zinenko } // namespace
394*73c3dff1SAlex Zinenko 
395*73c3dff1SAlex Zinenko namespace mlir {
396*73c3dff1SAlex Zinenko namespace transform {
createCheckUsesPass()397*73c3dff1SAlex Zinenko std::unique_ptr<Pass> createCheckUsesPass() {
398*73c3dff1SAlex Zinenko   return std::make_unique<CheckUsesPass>();
399*73c3dff1SAlex Zinenko }
400*73c3dff1SAlex Zinenko } // namespace transform
401*73c3dff1SAlex Zinenko } // namespace mlir
402