1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
12 // call graphs.
13 //
14 // One-Shot Bufferize consists of two phases.
15 //
// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
//    function does not generate buffer copies for OpOperands that were decided
//    to bufferize inplace during the analysis phase.
22 //
23 // This file contains only the analysis. The actual bufferization is implemented
24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
25 // helper function `runOneShotBufferize` that analyzes an op (and its nested
26 // ops) and then bufferizes it.
27 //
28 // Inplace bufferization decisions are passed from the analysis to the
29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
30 // They can be printed for debugging purposes with `testAnalysisOnly`.
31 //
32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
33 // treated conservatively. E.g., the analysis has to assume that their tensor
34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
36 // inserted at the bufferization boundary.
37 //
38 // This analysis caters to high-performance codegen where buffer reuse is deemed
39 // critical: the analysis should fail if the bufferized form of the function
40 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
41 
42 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
43 
44 #include <random>
45 
46 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
47 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49 #include "mlir/Dialect/Func/IR/FuncOps.h"
50 #include "mlir/Dialect/MemRef/IR/MemRef.h"
51 #include "mlir/IR/AsmState.h"
52 #include "mlir/IR/Dominance.h"
53 #include "mlir/IR/Operation.h"
54 #include "mlir/IR/TypeUtilities.h"
55 #include "mlir/Interfaces/ControlFlowInterfaces.h"
56 #include "llvm/ADT/DenseSet.h"
57 #include "llvm/ADT/SetVector.h"
58 
59 using namespace mlir;
60 using namespace mlir::bufferization;
61 
62 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
63 
64 //===----------------------------------------------------------------------===//
65 // Bufferization-specific attribute manipulation.
66 // These are for testing and debugging only. Bufferization information is
67 // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
68 // is annotated with the results of the analysis (copied from
69 // BufferizationAliasInfo), so that they can be checked in tests.
70 //===----------------------------------------------------------------------===//
71 
72 /// Attribute marker to specify op results that can be bufferized inPlace.
73 constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
74 
75 /// Mark whether OpOperand will be bufferized inplace.
76 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
77   Operation *op = opOperand.getOwner();
78   auto attr =
79       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
80   SmallVector<StringRef> inPlaceVector;
81   if (attr) {
82     inPlaceVector = SmallVector<StringRef>(
83         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
84   } else {
85     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
86     for (OpOperand &opOperand : op->getOpOperands())
87       if (opOperand.get().getType().isa<TensorType>())
88         inPlaceVector[opOperand.getOperandNumber()] = "false";
89   }
90 
91   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
92   op->setAttr(kInPlaceResultsAttrName,
93               OpBuilder(op).getStrArrayAttr(inPlaceVector));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // BufferizationAliasInfo
98 //===----------------------------------------------------------------------===//
99 
100 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
101   rootOp->walk([&](Operation *op) {
102     for (Value v : op->getResults())
103       if (v.getType().isa<TensorType>())
104         createAliasInfoEntry(v);
105     for (Region &r : op->getRegions())
106       for (Block &b : r.getBlocks())
107         for (auto bbArg : b.getArguments())
108           if (bbArg.getType().isa<TensorType>())
109             createAliasInfoEntry(bbArg);
110   });
111 }
112 
113 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
114 /// beginning the alias and equivalence sets only contain `v` itself.
115 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
116   aliasInfo.insert(v);
117   equivalentInfo.insert(v);
118 }
119 
120 /// Insert an info entry for `newValue` and merge its alias set with that of
121 /// `alias`.
122 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
123   createAliasInfoEntry(newValue);
124   aliasInfo.unionSets(newValue, alias);
125 }
126 
127 /// Insert an info entry for `newValue` and merge its alias set with that of
128 /// `alias`. Additionally, merge their equivalence classes.
129 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
130                                                         Value alias) {
131   insertNewBufferAlias(newValue, alias);
132   equivalentInfo.unionSets(newValue, alias);
133 }
134 
135 /// Return `true` if a value was marked as in-place bufferized.
136 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
137   return inplaceBufferized.contains(&operand);
138 }
139 
140 /// Set the inPlace bufferization spec to true.
141 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
142                                               AnalysisState &state) {
143   markInPlace(operand);
144   for (OpResult result : state.getAliasingOpResult(operand))
145     aliasInfo.unionSets(result, operand.get());
146 }
147 
148 /// Set the inPlace bufferization spec to false.
149 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
150   assert(!inplaceBufferized.contains(&operand) &&
151          "OpOperand was already decided to bufferize inplace");
152 }
153 
154 /// Apply `fun` to all the members of the equivalence class of `v`.
155 void BufferizationAliasInfo::applyOnEquivalenceClass(
156     Value v, function_ref<void(Value)> fun) const {
157   auto leaderIt = equivalentInfo.findLeader(v);
158   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
159        ++mit) {
160     fun(*mit);
161   }
162 }
163 
164 /// Apply `fun` to all aliases of `v`.
165 void BufferizationAliasInfo::applyOnAliases(
166     Value v, function_ref<void(Value)> fun) const {
167   auto leaderIt = aliasInfo.findLeader(v);
168   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
169     fun(*mit);
170   }
171 }
172 
173 BufferizationAliasInfo::EquivalenceClassRangeType
174 BufferizationAliasInfo::getAliases(Value v) const {
175   DenseSet<Value> res;
176   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
177   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
178        mit != meit; ++mit) {
179     res.insert(static_cast<Value>(*mit));
180   }
181   return BufferizationAliasInfo::EquivalenceClassRangeType(
182       aliasInfo.member_begin(it), aliasInfo.member_end());
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // OneShotAnalysisState
187 //===----------------------------------------------------------------------===//
188 
189 OneShotAnalysisState::OneShotAnalysisState(
190     Operation *op, const OneShotBufferizationOptions &options)
191     : AnalysisState(options), aliasInfo(op) {
192   // Set up alias sets for OpResults that must bufferize in-place. This should
193   // be done before making any other bufferization decisions.
194   op->walk([&](BufferizableOpInterface bufferizableOp) {
195     if (!options.isOpAllowed(bufferizableOp))
196       return WalkResult::skip();
197     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
198       if (opOperand.get().getType().isa<TensorType>())
199         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
200           for (OpResult opResult :
201                bufferizableOp.getAliasingOpResult(opOperand, *this))
202             aliasInfo.unionAliasSets(opOperand.get(), opResult);
203           aliasInfo.markInPlace(opOperand);
204         }
205     }
206     return WalkResult::advance();
207   });
208 }
209 
210 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
211   return aliasInfo.isInPlace(opOperand);
212 }
213 
214 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
215                                                          Value v2) const {
216   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
217 }
218 
219 bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
220                                                        Value v2) const {
221   return aliasInfo.areAliasingBufferizedValues(v1, v2);
222 }
223 
224 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
225 // to ensure that such information is available during bufferization time.
226 // Alias information can no longer be queried through BufferizationAliasInfo
227 // once we have started modifying the IR.
228 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
229   op->walk([&](Operation *returnOp) {
230     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
231       return WalkResult::advance();
232 
233     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
234       Value returnVal = returnValOperand.get();
235       // Skip non-tensor values.
236       if (!returnVal.getType().isa<TensorType>())
237         continue;
238 
239       // Add all aliases of the returned value. But only the ones that are in
240       // the same block.
241       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
242         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
243           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
244             yieldedTensors.insert(bbArg);
245           return;
246         }
247         Operation *definingOp = v.getDefiningOp();
248         if (definingOp->getParentOp() == returnOp->getParentOp())
249           yieldedTensors.insert(v);
250       });
251     }
252 
253     return WalkResult::advance();
254   });
255 }
256 
257 void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
258   op->walk([&](Operation *op) {
259     // Skip unknown ops.
260     auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
261     if (!bufferizableOp)
262       return WalkResult::skip();
263 
264     // Check all tensor OpResults.
265     for (OpResult opResult : op->getOpResults()) {
266       if (!opResult.getType().isa<TensorType>())
267         continue;
268 
269       // If there is no preceding memory write, the tensor contents are
270       // undefined.
271       // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
272       // use-def chain, it returns that value, regardless of whether it is a
273       // memory write or not.
274       SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
275       bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
276         if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
277           return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
278                                               *this);
279         return true;
280       });
281       if (isUndefined)
282         for (OpOperand &use : opResult.getUses())
283           undefinedTensorUses.insert(&use);
284     }
285 
286     return WalkResult::advance();
287   });
288 }
289 
290 bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
291   return undefinedTensorUses.contains(opOperand);
292 }
293 
294 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
295   return yieldedTensors.contains(tensor);
296 }
297 
298 bool OneShotAnalysisState::isValueWritten(Value value) const {
299   bool isWritten = false;
300   aliasInfo.applyOnAliases(value, [&](Value val) {
301     for (OpOperand &use : val.getUses())
302       if (isInPlace(use) && bufferizesToMemoryWrite(use))
303         isWritten = true;
304   });
305   return isWritten;
306 }
307 
308 bool OneShotAnalysisState::isWritable(Value value) const {
309   // TODO: Out-of-place bufferized value could be considered writable.
310   if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
311     return bufferizableOp.isWritable(value, *this);
312 
313   // Query BufferizableOpInterface to see if the BlockArgument is writable.
314   if (auto bbArg = value.dyn_cast<BlockArgument>())
315     if (auto bufferizableOp =
316             getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
317       return bufferizableOp.isWritable(bbArg, *this);
318 
319   // Not a bufferizable op: The conservative answer is "not writable".
320   return false;
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // Bufferization-specific alias analysis.
325 //===----------------------------------------------------------------------===//
326 
327 /// Return true if opOperand has been decided to bufferize in-place.
328 static bool isInplaceMemoryWrite(OpOperand &opOperand,
329                                  const BufferizationAliasInfo &aliasInfo,
330                                  const AnalysisState &state) {
331   // OpOperands that do not bufferize to a memory write do not write in-place.
332   if (!state.bufferizesToMemoryWrite(opOperand))
333     return false;
334   // Check current bufferization decisions.
335   return aliasInfo.isInPlace(opOperand);
336 }
337 
338 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
339 /// properly dominates `b` and `b` is not inside `a`.
340 static bool happensBefore(Operation *a, Operation *b,
341                           const DominanceInfo &domInfo) {
342   do {
343     // TODO: Instead of isProperAncestor + properlyDominates, we should use
344     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
345     if (a->isProperAncestor(b))
346       return false;
347     if (domInfo.properlyDominates(a, b))
348       return true;
349   } while ((a = a->getParentOp()));
350   return false;
351 }
352 
353 /// For each given value, find the closest enclosing repetitive region. If this
354 /// is the same region for each value, return it. Otherwise return None.
355 /// Note: If there is no enclosing repetitive region, return nullptr.
356 static Optional<Region *>
357 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
358   if (values.empty())
359     return None;
360   Region *r = getEnclosingRepetitiveRegion(values.front());
361   for (Value value : values.drop_front())
362     if (getEnclosingRepetitiveRegion(value) != r)
363       return None;
364   return r;
365 }
366 
367 /// Return `true` if the given tensor value is a memory write. Most values are
368 /// tensor writes, but ops that define a tensor SSA value without specifying its
369 /// contents (e.g., alloc_tensor) are not.
370 static bool isMemoryWrite(Value value, const AnalysisState &state) {
371   auto opResult = value.dyn_cast<OpResult>();
372   if (!opResult)
373     return true;
374   auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
375   if (!bufferizableOp)
376     return true;
377   return bufferizableOp.isMemoryWrite(opResult, state);
378 }
379 
380 /// Annotate IR with details about the detected RaW conflict.
381 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
382                              Value lastWrite) {
383   static uint64_t counter = 0;
384   Operation *readingOp = uRead->getOwner();
385   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
386 
387   OpBuilder b(conflictingWritingOp->getContext());
388   std::string id = "C_" + std::to_string(counter++);
389 
390   std::string conflictingWriteAttr =
391       id +
392       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
393       "]";
394   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
395 
396   std::string readAttr =
397       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
398   readingOp->setAttr(readAttr, b.getUnitAttr());
399 
400   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
401     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
402                                 std::to_string(opResult.getResultNumber()) +
403                                 "]";
404     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
405   } else {
406     auto bbArg = lastWrite.cast<BlockArgument>();
407     std::string lastWriteAttr =
408         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
409     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
410   }
411 }
412 
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
///
/// `usesRead` holds the reading uses and `usesWrite` the (in-place) writing
/// uses that are checked against each other. If `options.printConflicts` is
/// set, the first conflict found is annotated on the IR (`annotateConflict`)
/// before returning. Note: `aliasInfo` is not used by this function.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases. Skip over aliases that are not actual writes
  // (as decided by `isMemoryWrite` on the value currently used by the write).
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    if (isMemoryWrite(uWrite->get(), state))
      writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `writtenAliases` is empty, there are no memory writes outside
      // of the repetitive region of conflictingWritingOp, which means that all
      // relevant aliases are inside the same repetitive region.
      bool canUseOpDominance =
          writtenAliases.empty() ||
          repetitiveRegionOfWrites ==
              getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if readingOp happens before conflictingWritingOp, i.e.,
      // the write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      //       a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so. Both the reading op and (if
      // different) the conflicting writing op are consulted.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes. The first last write that satisfies
      // none of the "no conflict" conditions below is reported as a conflict.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use (i.e., the conflicting write's single aliasing OpResult is the
        // last write).
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}
578 
579 // Helper function to iterate on aliases of `root` and capture the writes.
580 static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
581                                      const BufferizationAliasInfo &aliasInfo,
582                                      const AnalysisState &state) {
583   aliasInfo.applyOnAliases(root, [&](Value alias) {
584     for (auto &use : alias.getUses())
585       // Inplace write to a value that aliases root.
586       if (isInplaceMemoryWrite(use, aliasInfo, state))
587         res.insert(&use);
588   });
589 }
590 
591 // Helper function to iterate on aliases of `root` and capture the reads.
592 static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
593                              const BufferizationAliasInfo &aliasInfo,
594                              const AnalysisState &state) {
595   aliasInfo.applyOnAliases(root, [&](Value alias) {
596     for (auto &use : alias.getUses())
597       // Read to a value that aliases root.
598       if (state.bufferizesToMemoryRead(use))
599         res.insert(&use);
600   });
601 }
602 
603 /// Return true if bufferizing `operand` inplace would create a conflict. A read
604 /// R and a write W of the same alias set is a conflict if inplace bufferization
605 /// of W changes the value read by R to a value different from the one that
606 /// would be expected by tracing back R's origin through SSA use-def chains.
607 /// A conflict can only be introduced by a new alias and/or an inplace
608 /// bufferization decision.
609 ///
610 /// Example:
611 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
612 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
613 /// %e = tensor.extract_slice %1
614 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
615 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
616 ///
617 /// In the above example, the two TransferWriteOps have already been decided to
618 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
619 /// conflict because:
620 /// * According to SSA use-def chains, we expect to read the result of %1.
621 /// * However, adding an alias {%0, %t} would mean that the second
622 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
623 ///   would no longer be reading the result of %1.
624 ///
625 /// If `checkConsistencyOnly` is true, this function checks if there is a
626 /// read-after-write conflict without bufferizing `operand` inplace. This would
627 /// indicate a problem with the current inplace bufferization decisions.
628 ///
629 /// Note: If `checkConsistencyOnly`, this function may be called with a null
630 /// OpResult. In that case, only the consistency of bufferization decisions
631 /// involving aliases of the given OpOperand are checked.
632 static bool wouldCreateReadAfterWriteInterference(
633     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
634     const BufferizationAliasInfo &aliasInfo,
635     bool checkConsistencyOnly = false) {
636   // Collect reads and writes of all aliases of OpOperand and OpResult.
637   DenseSet<OpOperand *> usesRead, usesWrite;
638   getAliasingReads(usesRead, operand.get(), aliasInfo, state);
639   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
640   for (OpResult result : state.getAliasingOpResult(operand)) {
641     getAliasingReads(usesRead, result, aliasInfo, state);
642     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
643   }
644   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
645     usesWrite.insert(&operand);
646 
647   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
648                                        aliasInfo);
649 }
650 
651 /// Check the reverse SSA use-def chain (following aliasing OpOperands) for
652 /// non-writable tensor values. Stop searching when an out-of-place bufferized
653 /// OpOperand was found (or when the OpOperand was not bufferized yet).
654 /// `currentOpOperand` is assumed to be in-place, even if that decision was not
655 /// materialized in `aliasInfo` yet.
656 static bool
657 hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
658                                       const BufferizationAliasInfo &aliasInfo,
659                                       const OneShotAnalysisState &state) {
660   SmallVector<Value> worklist;
661   worklist.push_back(value);
662   while (!worklist.empty()) {
663     Value nextVal = worklist.pop_back_val();
664     if (!state.isWritable(nextVal))
665       return true;
666 
667     // If `nextVal` is not a BlockArgument: End of use-def chain reached.
668     auto opResult = nextVal.dyn_cast<OpResult>();
669     if (!opResult)
670       continue;
671 
672     // Follow reverse SSA use-def chain.
673     SmallVector<OpOperand *> aliasingOpOperands =
674         state.getAliasingOpOperand(opResult);
675     for (OpOperand *opOperand : aliasingOpOperands)
676       if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
677         worklist.push_back(opOperand->get());
678   }
679   return false;
680 }
681 
682 /// Return true if bufferizing `operand` inplace would create a write to a
683 /// non-writable buffer.
684 static bool wouldCreateWriteToNonWritableBuffer(
685     OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
686     OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
687   // Collect writes of all aliases of OpOperand and OpResult.
688   DenseSet<OpOperand *> usesWrite;
689   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
690   for (OpResult result : state.getAliasingOpResult(operand)) {
691     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
692   }
693   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
694     usesWrite.insert(&operand);
695 
696   // Assuming that `operand` bufferizes in-place: For each write (to each
697   // alias), check if there is a non-writable tensor in the reverse SSA use-def
698   // chain.
699   for (OpOperand *uWrite : usesWrite)
700     if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
701                                               aliasInfo, state))
702       return true;
703 
704   return false;
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // Bufferization analyses.
709 //===----------------------------------------------------------------------===//
710 
711 /// Determine if `operand` can be bufferized in-place.
712 static LogicalResult bufferizableInPlaceAnalysisImpl(
713     OpOperand &operand, BufferizationAliasInfo &aliasInfo,
714     OneShotAnalysisState &state, const DominanceInfo &domInfo) {
715   bool foundInterference =
716       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
717       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
718 
719   if (foundInterference)
720     aliasInfo.bufferizeOutOfPlace(operand);
721   else
722     aliasInfo.bufferizeInPlace(operand, state);
723 
724   return success();
725 }
726 
727 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
728 /// reverse and bufferize ops greedily. This is a good starter heuristic.
729 ///
730 /// Even if an op does not read or write, it may still create an alias when
731 /// bufferized in-place. An example of such ops is tensor.extract_slice.
732 ///
733 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
734 ///
735 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
736 /// cannot change the flow of information for either the source or the
737 /// result buffers.
738 ///
739 /// When bufferized inplace, an ExtractSliceOp does not by itself create any
740 /// read or write from memory. Instead, it has the effect of merging the alias
741 /// sets of the source and the result buffers.
742 ///
743 /// An analysis is required to ensure inplace bufferization would not result in
744 /// RaW dependence violations.
745 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
746                                      BufferizationAliasInfo &aliasInfo,
747                                      OneShotAnalysisState &state,
748                                      const DominanceInfo &domInfo,
749                                      unsigned analysisFuzzerSeed = 0) {
750   if (analysisFuzzerSeed) {
751     // This is a fuzzer. For testing purposes only. Randomize the order in which
752     // operations are analyzed. The bufferization quality is likely worse, but
753     // we want to make sure that no assertions are triggered anywhere.
754     std::mt19937 g(analysisFuzzerSeed);
755     llvm::shuffle(ops.begin(), ops.end(), g);
756   }
757 
758   // Walk ops in reverse for better interference analysis.
759   for (Operation *op : reverse(ops))
760     for (OpOperand &opOperand : op->getOpOperands())
761       if (opOperand.get().getType().isa<TensorType>())
762         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
763           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
764                                                      state, domInfo)))
765             return failure();
766 
767   return success();
768 }
769 
770 /// Return true if the given op has a tensor result or a tensor operand.
771 static bool hasTensorSemantics(Operation *op) {
772   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
773   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
774   return hasTensorResult || hasTensorOperand;
775 }
776 
777 /// Analyze all ops that are contained in `op`.
778 static LogicalResult inPlaceAnalysis(Operation *op,
779                                      BufferizationAliasInfo &aliasInfo,
780                                      OneShotAnalysisState &state,
781                                      const DominanceInfo &domInfo,
782                                      unsigned analysisFuzzerSeed = 0) {
783   // Collect ops so we can build our own reverse traversal.
784   SmallVector<Operation *> ops;
785   op->walk([&](Operation *op) {
786     // No tensors => no buffers.
787     if (!hasTensorSemantics(op))
788       return;
789     ops.push_back(op);
790   });
791 
792   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
793 }
794 
795 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
796 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
797                                 BufferizationAliasInfo &aliasInfo,
798                                 AnalysisState &state) {
799   for (Operation *op : ops)
800     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
801       for (OpResult opResult : op->getOpResults())
802         if (opResult.getType().isa<TensorType>())
803           for (OpOperand *opOperand :
804                bufferizableOp.getAliasingOpOperand(opResult, state))
805             if (state.isInPlace(*opOperand))
806               if (bufferizableOp.bufferRelation(opResult, state) ==
807                   BufferRelation::Equivalent)
808                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
809 }
810 
811 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
812 /// in `op`.
813 static void equivalenceAnalysis(Operation *op,
814                                 BufferizationAliasInfo &aliasInfo,
815                                 AnalysisState &state) {
816   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
817   SmallVector<Operation *> ops;
818   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
819     // No tensors => no buffers.
820     if (none_of(op->getResultTypes(), isaTensor))
821       return;
822     ops.push_back(op);
823   });
824 
825   equivalenceAnalysis(ops, aliasInfo, state);
826 }
827 
828 /// Assert that the current bufferization decisions are consistent.
829 static LogicalResult
830 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
831                           AnalysisState &state,
832                           const BufferizationAliasInfo &aliasInfo) {
833   const BufferizationOptions &options = state.getOptions();
834   Operation *inconsistentOp = nullptr;
835   WalkResult walkResult = op->walk([&](Operation *op) {
836     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
837       for (OpOperand &opOperand : op->getOpOperands())
838         if (opOperand.get().getType().isa<TensorType>()) {
839           if (wouldCreateReadAfterWriteInterference(
840                   opOperand, domInfo, state, aliasInfo,
841                   /*checkConsistencyOnly=*/true)) {
842             // This error can happen if certain "mustBufferizeInPlace" interface
843             // methods are implemented incorrectly, such that the IR already has
844             // a RaW conflict before making any bufferization decisions.
845             inconsistentOp = op;
846             return WalkResult::interrupt();
847           }
848         }
849     return WalkResult::advance();
850   });
851 
852   if (walkResult.wasInterrupted())
853     return inconsistentOp->emitError("input IR has RaW conflict");
854   return success();
855 }
856 
857 /// Annotate the IR with the result of the analysis. For testing/debugging only.
858 static void
859 annotateOpsWithBufferizationMarkers(Operation *op,
860                                     const BufferizationAliasInfo &aliasInfo,
861                                     AnalysisState &state) {
862   op->walk([&](Operation *op) {
863     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
864       for (OpOperand &opOperand : op->getOpOperands())
865         if (opOperand.get().getType().isa<TensorType>())
866           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
867   });
868 }
869 
870 /// Assert that IR is in destination-passing style. I.e., every value that is
871 /// returned or yielded from a block is:
872 /// * aliasing a bbArg of that block or a parent block, or
873 /// * aliasing an OpResult of a op in a parent block.
874 ///
875 /// Example:
876 /// ```
877 /// %0 = "some_op" : tensor<?xf32>
878 /// %1 = scf.if %c -> (tensor<?xf32>) {
879 ///   scf.yield %0 : tensor<?xf32>
880 /// } else {
881 ///   %t = linalg.alloc_tensor : tensor<?xf32>
882 ///   scf.yield %t : tensor<?xf32>
883 /// }
884 /// ```
885 /// In the above example, the first scf.yield op satifies destination-passing
886 /// style because the yielded value %0 is defined in the parent block. The
887 /// second scf.yield op does not satisfy destination-passing style because the
888 /// yielded value %t is defined in the same block as the scf.yield op.
889 // TODO: The current implementation checks for equivalent values instead of
890 // aliasing values, which is stricter than needed. We can currently not check
891 // for aliasing values because the analysis is a maybe-alias analysis and we
892 // need a must-alias analysis here.
893 static LogicalResult
894 assertDestinationPassingStyle(Operation *op, AnalysisState &state,
895                               BufferizationAliasInfo &aliasInfo,
896                               SmallVector<Operation *> &newOps) {
897   LogicalResult status = success();
898   DominanceInfo domInfo(op);
899   op->walk([&](Operation *returnOp) {
900     if (!isRegionReturnLike(returnOp) ||
901         !state.getOptions().isOpAllowed(returnOp))
902       return WalkResult::advance();
903 
904     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
905       Value returnVal = returnValOperand.get();
906       // Skip non-tensor values.
907       if (!returnVal.getType().isa<TensorType>())
908         continue;
909 
910       bool foundEquivValue = false;
911       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
912         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
913           Operation *definingOp = bbArg.getOwner()->getParentOp();
914           if (definingOp->isProperAncestor(returnOp))
915             foundEquivValue = true;
916           return;
917         }
918 
919         Operation *definingOp = equivVal.getDefiningOp();
920         if (definingOp->getBlock()->findAncestorOpInBlock(
921                 *returnOp->getParentOp()))
922           // Skip ops that happen after `returnOp` and parent ops.
923           if (happensBefore(definingOp, returnOp, domInfo))
924             foundEquivValue = true;
925       });
926 
927       if (!foundEquivValue)
928         status =
929             returnOp->emitError()
930             << "operand #" << returnValOperand.getOperandNumber()
931             << " of ReturnLike op does not satisfy destination passing style";
932     }
933 
934     return WalkResult::advance();
935   });
936 
937   return status;
938 }
939 
940 LogicalResult bufferization::analyzeOp(Operation *op,
941                                        OneShotAnalysisState &state) {
942   DominanceInfo domInfo(op);
943   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
944   const auto &options =
945       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
946 
947   // Catch incorrect API usage.
948   assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
949           !options.bufferizeFunctionBoundaries) &&
950          "must use ModuleBufferize to bufferize function boundaries");
951 
952   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
953     return failure();
954 
955   // If the analysis fails, just return.
956   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
957                              options.analysisFuzzerSeed)))
958     return failure();
959   equivalenceAnalysis(op, aliasInfo, state);
960 
961   bool failedAnalysis = false;
962   if (!options.allowReturnAllocs) {
963     SmallVector<Operation *> newOps;
964     failedAnalysis |=
965         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
966   }
967 
968   // Gather some extra analysis data.
969   state.gatherYieldedTensors(op);
970   state.gatherUndefinedTensorUses(op);
971 
972   // Analysis verification: After setting up alias/equivalence sets, each op
973   // can check for expected invariants/limitations and fail the analysis if
974   // necessary.
975   op->walk([&](Operation *op) {
976     if (BufferizableOpInterface bufferizableOp =
977             options.dynCastBufferizableOp(op))
978       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
979   });
980 
981   // Annotate operations if we only want to report the analysis.
982   if (options.testAnalysisOnly)
983     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
984 
985   return success(!failedAnalysis);
986 }
987 
988 LogicalResult
989 bufferization::runOneShotBufferize(Operation *op,
990                                    const OneShotBufferizationOptions &options) {
991   OneShotAnalysisState state(op, options);
992   if (failed(analyzeOp(op, state)))
993     return failure();
994   if (options.testAnalysisOnly)
995     return success();
996   return bufferizeOp(op, state);
997 }
998