1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
12 // call graphs.
13 //
14 // One-Shot Bufferize consists of two phases.
15 //
16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
17 //    inserting buffer copies. The analysis queries op bufferization semantics
18 //    via `BufferizableOpInterface`.
19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
20 //    function does not generate buffer copies for OpResults that were decided
21 //    to bufferize inplace during the analysis phase.
22 //
23 // This file contains only the analysis. The actual bufferization is implemented
24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
25 // helper function `runOneShotBufferize` that analyzes an op (and its nested
26 // ops) and then bufferizes it.
27 //
28 // Inplace bufferization decisions are passed from the analysis to the
29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
30 // They can be printed for debugging purposes with `testAnalysisOnly`.
31 //
32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
33 // treated conservatively. E.g., the analysis has to assume that their tensor
34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
36 // inserted at the bufferization boundary.
37 //
38 // This analysis caters to high-performance codegen where buffer reuse is deemed
39 // critical: the analysis should fail if the bufferized form of the function
40 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
41 
42 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
43 
44 #include <random>
45 
46 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
47 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49 #include "mlir/Dialect/Func/IR/FuncOps.h"
50 #include "mlir/Dialect/MemRef/IR/MemRef.h"
51 #include "mlir/IR/AsmState.h"
52 #include "mlir/IR/Dominance.h"
53 #include "mlir/IR/Operation.h"
54 #include "mlir/IR/TypeUtilities.h"
55 #include "mlir/Interfaces/ControlFlowInterfaces.h"
56 #include "llvm/ADT/DenseSet.h"
57 #include "llvm/ADT/SetVector.h"
58 
59 using namespace mlir;
60 using namespace mlir::bufferization;
61 
62 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
63 
64 //===----------------------------------------------------------------------===//
65 // Bufferization-specific attribute manipulation.
66 // These are for testing and debugging only. Bufferization information is
67 // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
68 // is annotated with the results of the analysis (copied from
69 // BufferizationAliasInfo), so that they can be checked in tests.
70 //===----------------------------------------------------------------------===//
71 
/// Attribute marker to annotate, per OpOperand, whether it bufferizes in-place
/// ("true"/"false") or is not a tensor operand at all ("none"). For testing
/// and debugging only.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
74 
75 /// Mark whether OpOperand will be bufferized inplace.
76 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
77   Operation *op = opOperand.getOwner();
78   auto attr =
79       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
80   SmallVector<StringRef> inPlaceVector;
81   if (attr) {
82     inPlaceVector = SmallVector<StringRef>(
83         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
84   } else {
85     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
86     for (OpOperand &opOperand : op->getOpOperands())
87       if (opOperand.get().getType().isa<TensorType>())
88         inPlaceVector[opOperand.getOperandNumber()] = "false";
89   }
90 
91   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
92   op->setAttr(kInPlaceResultsAttrName,
93               OpBuilder(op).getStrArrayAttr(inPlaceVector));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // BufferizationAliasInfo
98 //===----------------------------------------------------------------------===//
99 
100 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
101   rootOp->walk([&](Operation *op) {
102     for (Value v : op->getResults())
103       if (v.getType().isa<TensorType>())
104         createAliasInfoEntry(v);
105     for (Region &r : op->getRegions())
106       for (Block &b : r.getBlocks())
107         for (auto bbArg : b.getArguments())
108           if (bbArg.getType().isa<TensorType>())
109             createAliasInfoEntry(bbArg);
110   });
111 }
112 
113 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
114 /// beginning the alias and equivalence sets only contain `v` itself.
115 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
116   aliasInfo.insert(v);
117   equivalentInfo.insert(v);
118 }
119 
120 /// Insert an info entry for `newValue` and merge its alias set with that of
121 /// `alias`.
122 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
123   createAliasInfoEntry(newValue);
124   aliasInfo.unionSets(newValue, alias);
125 }
126 
127 /// Insert an info entry for `newValue` and merge its alias set with that of
128 /// `alias`. Additionally, merge their equivalence classes.
129 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
130                                                         Value alias) {
131   insertNewBufferAlias(newValue, alias);
132   equivalentInfo.unionSets(newValue, alias);
133 }
134 
135 /// Return `true` if a value was marked as in-place bufferized.
136 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
137   return inplaceBufferized.contains(&operand);
138 }
139 
140 /// Set the inPlace bufferization spec to true.
141 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
142                                               AnalysisState &state) {
143   markInPlace(operand);
144   for (OpResult result : state.getAliasingOpResult(operand))
145     aliasInfo.unionSets(result, operand.get());
146 }
147 
/// Set the inPlace bufferization spec to false.
/// Note: Out-of-place is the implicit default (`inplaceBufferized` only
/// records positive decisions), so no state needs to be updated here. The
/// assert guards against an earlier, contradictory in-place decision for the
/// same operand.
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
}
153 
154 /// Apply `fun` to all the members of the equivalence class of `v`.
155 void BufferizationAliasInfo::applyOnEquivalenceClass(
156     Value v, function_ref<void(Value)> fun) const {
157   auto leaderIt = equivalentInfo.findLeader(v);
158   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
159        ++mit) {
160     fun(*mit);
161   }
162 }
163 
164 /// Apply `fun` to all aliases of `v`.
165 void BufferizationAliasInfo::applyOnAliases(
166     Value v, function_ref<void(Value)> fun) const {
167   auto leaderIt = aliasInfo.findLeader(v);
168   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
169     fun(*mit);
170   }
171 }
172 
173 BufferizationAliasInfo::EquivalenceClassRangeType
174 BufferizationAliasInfo::getAliases(Value v) const {
175   DenseSet<Value> res;
176   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
177   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
178        mit != meit; ++mit) {
179     res.insert(static_cast<Value>(*mit));
180   }
181   return BufferizationAliasInfo::EquivalenceClassRangeType(
182       aliasInfo.member_begin(it), aliasInfo.member_end());
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // OneShotAnalysisState
187 //===----------------------------------------------------------------------===//
188 
189 OneShotAnalysisState::OneShotAnalysisState(
190     Operation *op, const OneShotBufferizationOptions &options)
191     : AnalysisState(options), aliasInfo(op) {
192   // Set up alias sets for OpResults that must bufferize in-place. This should
193   // be done before making any other bufferization decisions.
194   op->walk([&](BufferizableOpInterface bufferizableOp) {
195     if (!options.isOpAllowed(bufferizableOp))
196       return WalkResult::skip();
197     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
198       if (opOperand.get().getType().isa<TensorType>())
199         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
200           for (OpResult opResult :
201                bufferizableOp.getAliasingOpResult(opOperand, *this))
202             aliasInfo.unionAliasSets(opOperand.get(), opResult);
203           aliasInfo.markInPlace(opOperand);
204         }
205     }
206     return WalkResult::advance();
207   });
208 }
209 
210 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
211   return aliasInfo.isInPlace(opOperand);
212 }
213 
214 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
215                                                          Value v2) const {
216   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
217 }
218 
219 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
220 // to ensure that such information is available during bufferization time.
221 // Alias information can no longer be queried through BufferizationAliasInfo
222 // once we have started modifying the IR.
223 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
224   op->walk([&](Operation *returnOp) {
225     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
226       return WalkResult::advance();
227 
228     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
229       Value returnVal = returnValOperand.get();
230       // Skip non-tensor values.
231       if (!returnVal.getType().isa<TensorType>())
232         continue;
233 
234       // Add all aliases of the returned value. But only the ones that are in
235       // the same block.
236       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
237         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
238           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
239             yieldedTensors.insert(bbArg);
240           return;
241         }
242         Operation *definingOp = v.getDefiningOp();
243         if (definingOp->getParentOp() == returnOp->getParentOp())
244           yieldedTensors.insert(v);
245       });
246     }
247 
248     return WalkResult::advance();
249   });
250 }
251 
252 void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
253   op->walk([&](Operation *op) {
254     // Skip unknown ops.
255     auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
256     if (!bufferizableOp)
257       return WalkResult::skip();
258 
259     // Check all tensor OpResults.
260     for (OpResult opResult : op->getOpResults()) {
261       if (!opResult.getType().isa<TensorType>())
262         continue;
263 
264       // If there is no preceding memory write, the tensor contents are
265       // undefined.
266       // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
267       // use-def chain, it returns that value, regardless of whether it is a
268       // memory write or not.
269       SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
270       bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
271         if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
272           return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
273                                               *this);
274         return true;
275       });
276       if (isUndefined)
277         for (OpOperand &use : opResult.getUses())
278           undefinedTensorUses.insert(&use);
279     }
280 
281     return WalkResult::advance();
282   });
283 }
284 
285 bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
286   return undefinedTensorUses.contains(opOperand);
287 }
288 
289 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
290   return yieldedTensors.contains(tensor);
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Bufferization-specific alias analysis.
295 //===----------------------------------------------------------------------===//
296 
297 /// Return true if opOperand has been decided to bufferize in-place.
298 static bool isInplaceMemoryWrite(OpOperand &opOperand,
299                                  const BufferizationAliasInfo &aliasInfo,
300                                  AnalysisState &state) {
301   // OpOperands that do not bufferize to a memory write do not write in-place.
302   if (!state.bufferizesToMemoryWrite(opOperand))
303     return false;
304   // Check current bufferization decisions.
305   return aliasInfo.isInPlace(opOperand);
306 }
307 
308 /// Return true if, under current bufferization decisions, the buffer of `value`
309 /// is not writable.
310 static bool aliasesNonWritableBuffer(Value value,
311                                      const BufferizationAliasInfo &aliasInfo,
312                                      AnalysisState &state) {
313   bool foundNonWritableBuffer = false;
314   aliasInfo.applyOnAliases(value, [&](Value v) {
315     // Query BufferizableOpInterface to see if the value is writable.
316     // TODO: Out-of-place bufferized value could be considered writable.
317     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
318       if (bufferizableOp && bufferizableOp.isWritable(v, state))
319         return;
320 
321     // Query BufferizableOpInterface to see if the BlockArgument is writable.
322     if (auto bbArg = v.dyn_cast<BlockArgument>())
323       if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
324               bbArg.getOwner()->getParentOp()))
325         if (bufferizableOp.isWritable(bbArg, state))
326           return;
327 
328     foundNonWritableBuffer = true;
329   });
330 
331   return foundNonWritableBuffer;
332 }
333 
334 /// Return true if the buffer to which `operand` would bufferize is equivalent
335 /// to some buffer write.
336 static bool aliasesInPlaceWrite(Value value,
337                                 const BufferizationAliasInfo &aliasInfo,
338                                 AnalysisState &state) {
339   bool foundInplaceWrite = false;
340   aliasInfo.applyOnAliases(value, [&](Value v) {
341     for (auto &use : v.getUses()) {
342       if (isInplaceMemoryWrite(use, aliasInfo, state)) {
343         foundInplaceWrite = true;
344         return;
345       }
346     }
347   });
348   return foundInplaceWrite;
349 }
350 
351 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
352 /// properly dominates `b` and `b` is not inside `a`.
353 static bool happensBefore(Operation *a, Operation *b,
354                           const DominanceInfo &domInfo) {
355   do {
356     // TODO: Instead of isProperAncestor + properlyDominates, we should use
357     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
358     if (a->isProperAncestor(b))
359       return false;
360     if (domInfo.properlyDominates(a, b))
361       return true;
362   } while ((a = a->getParentOp()));
363   return false;
364 }
365 
366 /// For each given value, find the closest enclosing repetitive region. If this
367 /// is the same region for each value, return it. Otherwise return None.
368 /// Note: If there is no enclosing repetitive region, return nullptr.
369 static Optional<Region *>
370 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
371   if (values.empty())
372     return None;
373   Region *r = getEnclosingRepetitiveRegion(values.front());
374   for (Value value : values.drop_front())
375     if (getEnclosingRepetitiveRegion(value) != r)
376       return None;
377   return r;
378 }
379 
380 /// Return `true` if the given tensor value is a memory write. Most values are
381 /// tensor writes, but ops that define a tensor SSA value without specifying its
382 /// contents (e.g., alloc_tensor) are not.
383 static bool isMemoryWrite(Value value, const AnalysisState &state) {
384   auto opResult = value.dyn_cast<OpResult>();
385   if (!opResult)
386     return true;
387   auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
388   if (!bufferizableOp)
389     return true;
390   return bufferizableOp.isMemoryWrite(opResult, state);
391 }
392 
393 /// Annotate IR with details about the detected RaW conflict.
394 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
395                              Value lastWrite) {
396   static uint64_t counter = 0;
397   Operation *readingOp = uRead->getOwner();
398   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
399 
400   OpBuilder b(conflictingWritingOp->getContext());
401   std::string id = "C_" + std::to_string(counter++);
402 
403   std::string conflictingWriteAttr =
404       id +
405       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
406       "]";
407   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
408 
409   std::string readAttr =
410       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
411   readingOp->setAttr(readAttr, b.getUnitAttr());
412 
413   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
414     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
415                                 std::to_string(opResult.getResultNumber()) +
416                                 "]";
417     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
418   } else {
419     auto bbArg = lastWrite.cast<BlockArgument>();
420     std::string lastWriteAttr =
421         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
422     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
423   }
424 }
425 
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases. Skip over aliases that are not actual writes.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    if (isMemoryWrite(uWrite->get(), state))
      writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `writtenAliases` is empty, there are no memory writes outside
      // of the repetitive region of conflictingWritingOp, which means that all
      // relevant aliases are inside the same repetitive region.
      bool canUseOpDominance =
          writtenAliases.empty() ||
          repetitiveRegionOfWrites ==
              getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      //       a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      // Also give the writing op's interface a chance to declare the pair
      // non-conflicting (skipped if it is the same op as the reading op, which
      // was already queried above).
      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}
591 
592 /// Return true if bufferizing `operand` inplace would create a conflict. A read
593 /// R and a write W of the same alias set is a conflict if inplace bufferization
594 /// of W changes the value read by R to a value different from the one that
595 /// would be expected by tracing back R's origin through SSA use-def chains.
596 /// A conflict can only be introduced by a new alias and/or an inplace
597 /// bufferization decision.
598 ///
599 /// Example:
600 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
601 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
602 /// %e = tensor.extract_slice %1
603 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
604 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
605 ///
606 /// In the above example, the two TransferWriteOps have already been decided to
607 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
608 /// conflict because:
609 /// * According to SSA use-def chains, we expect to read the result of %1.
610 /// * However, adding an alias {%0, %t} would mean that the second
611 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
612 ///   would no longer be reading the result of %1.
613 ///
614 /// If `checkConsistencyOnly` is true, this function checks if there is a
615 /// read-after-write conflict without bufferizing `operand` inplace. This would
616 /// indicate a problem with the current inplace bufferization decisions.
617 ///
618 /// Note: If `checkConsistencyOnly`, this function may be called with a null
619 /// OpResult. In that case, only the consistency of bufferization decisions
620 /// involving aliases of the given OpOperand are checked.
621 static bool wouldCreateReadAfterWriteInterference(
622     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
623     const BufferizationAliasInfo &aliasInfo,
624     bool checkConsistencyOnly = false) {
625   // Helper function to iterate on aliases of `root` and capture the reads.
626   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
627     aliasInfo.applyOnAliases(root, [&](Value alias) {
628       for (auto &use : alias.getUses())
629         // Read to a value that aliases root.
630         if (state.bufferizesToMemoryRead(use))
631           res.insert(&use);
632     });
633   };
634 
635   // Helper function to iterate on aliases of `root` and capture the writes.
636   auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
637     aliasInfo.applyOnAliases(root, [&](Value alias) {
638       for (auto &use : alias.getUses())
639         // Inplace write to a value that aliases root.
640         if (isInplaceMemoryWrite(use, aliasInfo, state))
641           res.insert(&use);
642     });
643   };
644 
645   // Collect reads and writes of all aliases of OpOperand and OpResult.
646   DenseSet<OpOperand *> usesRead, usesWrite;
647   getAliasingReads(usesRead, operand.get());
648   getAliasingInplaceWrites(usesWrite, operand.get());
649   for (OpResult result : state.getAliasingOpResult(operand)) {
650     getAliasingReads(usesRead, result);
651     getAliasingInplaceWrites(usesWrite, result);
652   }
653   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
654     usesWrite.insert(&operand);
655 
656   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
657                                        aliasInfo);
658 }
659 
660 /// Return true if bufferizing `opOperand` inplace would create a write to a
661 /// non-writable buffer.
662 static bool
663 wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
664                                     const BufferizationAliasInfo &aliasInfo,
665                                     AnalysisState &state) {
666   // Certain buffers are not writeable:
667   //   1. A function bbArg that is not inplaceable or
668   //   2. A constant op.
669   bool nonWritable =
670       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
671   if (!nonWritable)
672     return false;
673 
674   // This is a problem only if the buffer is written to via some alias.
675   bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
676                   state.bufferizesToMemoryWrite(opOperand);
677 
678   for (OpResult opResult : state.getAliasingOpResult(opOperand))
679     hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
680 
681   return hasWrite;
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // Bufferization analyses.
686 //===----------------------------------------------------------------------===//
687 
688 /// Determine if `operand` can be bufferized in-place.
689 static LogicalResult bufferizableInPlaceAnalysisImpl(
690     OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
691     const DominanceInfo &domInfo) {
692   bool foundInterference =
693       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
694       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
695 
696   if (foundInterference)
697     aliasInfo.bufferizeOutOfPlace(operand);
698   else
699     aliasInfo.bufferizeInPlace(operand, state);
700 
701   return success();
702 }
703 
704 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
705 /// reverse and bufferize ops greedily. This is a good starter heuristic.
706 ///
707 /// Even if an op does not read or write, it may still create an alias when
708 /// bufferized in-place. An example of such ops is tensor.extract_slice.
709 ///
710 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
711 ///
712 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
713 /// cannot change the flow of information for either the source or the
714 /// result buffers.
715 ///
716 /// When bufferized inplace, an ExtractSliceOp does not by itself create any
717 /// read or write from memory. Instead, it has the effect of merging the alias
718 /// sets of the source and the result buffers.
719 ///
720 /// An analysis is required to ensure inplace bufferization would not result in
721 /// RaW dependence violations.
722 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
723                                      BufferizationAliasInfo &aliasInfo,
724                                      AnalysisState &state,
725                                      const DominanceInfo &domInfo,
726                                      unsigned analysisFuzzerSeed = 0) {
727   if (analysisFuzzerSeed) {
728     // This is a fuzzer. For testing purposes only. Randomize the order in which
729     // operations are analyzed. The bufferization quality is likely worse, but
730     // we want to make sure that no assertions are triggered anywhere.
731     std::mt19937 g(analysisFuzzerSeed);
732     llvm::shuffle(ops.begin(), ops.end(), g);
733   }
734 
735   // Walk ops in reverse for better interference analysis.
736   for (Operation *op : reverse(ops))
737     for (OpOperand &opOperand : op->getOpOperands())
738       if (opOperand.get().getType().isa<TensorType>())
739         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
740           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
741                                                      state, domInfo)))
742             return failure();
743 
744   return success();
745 }
746 
747 /// Return true if the given op has a tensor result or a tensor operand.
748 static bool hasTensorSemantics(Operation *op) {
749   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
750   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
751   return hasTensorResult || hasTensorOperand;
752 }
753 
754 /// Analyze all ops that are contained in `op`.
755 static LogicalResult inPlaceAnalysis(Operation *op,
756                                      BufferizationAliasInfo &aliasInfo,
757                                      AnalysisState &state,
758                                      const DominanceInfo &domInfo,
759                                      unsigned analysisFuzzerSeed = 0) {
760   // Collect ops so we can build our own reverse traversal.
761   SmallVector<Operation *> ops;
762   op->walk([&](Operation *op) {
763     // No tensors => no buffers.
764     if (!hasTensorSemantics(op))
765       return;
766     ops.push_back(op);
767   });
768 
769   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
770 }
771 
772 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
773 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
774                                 BufferizationAliasInfo &aliasInfo,
775                                 AnalysisState &state) {
776   for (Operation *op : ops)
777     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
778       for (OpResult opResult : op->getOpResults())
779         if (opResult.getType().isa<TensorType>())
780           for (OpOperand *opOperand :
781                bufferizableOp.getAliasingOpOperand(opResult, state))
782             if (state.isInPlace(*opOperand))
783               if (bufferizableOp.bufferRelation(opResult, state) ==
784                   BufferRelation::Equivalent)
785                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
786 }
787 
788 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
789 /// in `op`.
790 static void equivalenceAnalysis(Operation *op,
791                                 BufferizationAliasInfo &aliasInfo,
792                                 AnalysisState &state) {
793   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
794   SmallVector<Operation *> ops;
795   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
796     // No tensors => no buffers.
797     if (none_of(op->getResultTypes(), isaTensor))
798       return;
799     ops.push_back(op);
800   });
801 
802   equivalenceAnalysis(ops, aliasInfo, state);
803 }
804 
805 /// Assert that the current bufferization decisions are consistent.
806 static LogicalResult
807 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
808                           AnalysisState &state,
809                           const BufferizationAliasInfo &aliasInfo) {
810   const BufferizationOptions &options = state.getOptions();
811   Operation *inconsistentOp = nullptr;
812   WalkResult walkResult = op->walk([&](Operation *op) {
813     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
814       for (OpOperand &opOperand : op->getOpOperands())
815         if (opOperand.get().getType().isa<TensorType>()) {
816           if (wouldCreateReadAfterWriteInterference(
817                   opOperand, domInfo, state, aliasInfo,
818                   /*checkConsistencyOnly=*/true)) {
819             // This error can happen if certain "mustBufferizeInPlace" interface
820             // methods are implemented incorrectly, such that the IR already has
821             // a RaW conflict before making any bufferization decisions.
822             inconsistentOp = op;
823             return WalkResult::interrupt();
824           }
825         }
826     return WalkResult::advance();
827   });
828 
829   if (walkResult.wasInterrupted())
830     return inconsistentOp->emitError("input IR has RaW conflict");
831   return success();
832 }
833 
834 /// Annotate the IR with the result of the analysis. For testing/debugging only.
835 static void
836 annotateOpsWithBufferizationMarkers(Operation *op,
837                                     const BufferizationAliasInfo &aliasInfo,
838                                     AnalysisState &state) {
839   op->walk([&](Operation *op) {
840     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
841       for (OpOperand &opOperand : op->getOpOperands())
842         if (opOperand.get().getType().isa<TensorType>())
843           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
844   });
845 }
846 
/// Assert that IR is in destination-passing style. I.e., every value that is
/// returned or yielded from a block is:
/// * aliasing a bbArg of that block or a parent block, or
/// * aliasing an OpResult of a op in a parent block.
///
/// Example:
/// ```
/// %0 = "some_op" : tensor<?xf32>
/// %1 = scf.if %c -> (tensor<?xf32>) {
///   scf.yield %0 : tensor<?xf32>
/// } else {
///   %t = linalg.alloc_tensor : tensor<?xf32>
///   scf.yield %t : tensor<?xf32>
/// }
/// ```
/// In the above example, the first scf.yield op satifies destination-passing
/// style because the yielded value %0 is defined in the parent block. The
/// second scf.yield op does not satisfy destination-passing style because the
/// yielded value %t is defined in the same block as the scf.yield op.
// TODO: The current implementation checks for equivalent values instead of
// aliasing values, which is stricter than needed. We can currently not check
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
//
// NOTE(review): `newOps` is not used in this function; presumably the
// parameter exists so that the signature matches PostAnalysisStepFn —
// confirm against the callback type before removing it.
static LogicalResult
assertDestinationPassingStyle(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
  // `status` stays `success()` unless at least one violating operand is found.
  LogicalResult status = success();
  DominanceInfo domInfo(op);
  op->walk([&](Operation *returnOp) {
    // Only ReturnLike ops that are subject to bufferization are checked.
    if (!isRegionReturnLike(returnOp) ||
        !state.getOptions().isOpAllowed(returnOp))
      return WalkResult::advance();

    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
      Value returnVal = returnValOperand.get();
      // Skip non-tensor values.
      if (!returnVal.getType().isa<TensorType>())
        continue;

      // Search the equivalence class of the returned value for a value that
      // is defined outside of the region being returned from.
      bool foundEquivValue = false;
      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
          Operation *definingOp = bbArg.getOwner()->getParentOp();
          // A bbArg of a proper ancestor op satisfies destination-passing
          // style for this returnOp.
          if (definingOp->isProperAncestor(returnOp))
            foundEquivValue = true;
          return;
        }

        // Otherwise the equivalent value is an OpResult; it satisfies
        // destination-passing style if its defining op is in a parent block
        // and executes before `returnOp`.
        Operation *definingOp = equivVal.getDefiningOp();
        if (definingOp->getBlock()->findAncestorOpInBlock(
                *returnOp->getParentOp()))
          // Skip ops that happen after `returnOp` and parent ops.
          if (happensBefore(definingOp, returnOp, domInfo))
            foundEquivValue = true;
      });

      // No equivalent out-of-region value found: report the violation. Note
      // that the walk continues, so every violating operand is reported.
      if (!foundEquivValue)
        status =
            returnOp->emitError()
            << "operand #" << returnValOperand.getOperandNumber()
            << " of ReturnLike op does not satisfy destination passing style";
    }

    return WalkResult::advance();
  });

  return status;
}
916 
/// Run One-Shot Analysis on the given op: decide for each tensor OpOperand
/// whether it can bufferize in-place, run equivalence analysis and all
/// configured post-analysis steps, then verify the result. The analysis
/// outcome is stored in `state` (and its BufferizationAliasInfo).
LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state) {
  DominanceInfo domInfo(op);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());

  // Catch incorrect API usage.
  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
          !options.bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  // Fail early if the input IR already has a RaW conflict before any
  // bufferization decisions were made.
  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
    return failure();

  // If the analysis fails, just return.
  if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                             options.analysisFuzzerSeed)))
    return failure();
  // Union equivalence classes of tied OpResult/OpOperand pairs. Must run after
  // the in-place analysis because it only considers in-place operands.
  equivalenceAnalysis(op, aliasInfo, state);

  // Run user-provided post-analysis steps in order. Each step may create new
  // ops, which must then be analyzed as well.
  for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
    SmallVector<Operation *> newOps;
    if (failed(fn(op, state, aliasInfo, newOps)))
      return failure();
    // Analyze ops that were created by the PostAnalysisStepFn.
    if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
      return failure();
    equivalenceAnalysis(newOps, aliasInfo, state);
  }

  // Subsequent checks accumulate into `failedAnalysis` instead of returning
  // early, so that all problems are reported before the analysis fails.
  bool failedAnalysis = false;
  if (!options.allowReturnAllocs) {
    SmallVector<Operation *> newOps;
    failedAnalysis |=
        failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
  }

  // Gather some extra analysis data.
  state.gatherYieldedTensors(op);
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

  return success(!failedAnalysis);
}
974 
975 LogicalResult
976 bufferization::runOneShotBufferize(Operation *op,
977                                    const OneShotBufferizationOptions &options) {
978   OneShotAnalysisState state(op, options);
979   if (failed(analyzeOp(op, state)))
980     return failure();
981   if (options.testAnalysisOnly)
982     return success();
983   return bufferizeOp(op, state);
984 }
985