1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
12 // call graphs.
13 //
14 // One-Shot Bufferize consists of two phases.
15 //
16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
17 //    inserting buffer copies. The analysis queries op bufferization semantics
18 //    via `BufferizableOpInterface`.
19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
20 //    function does not generate buffer copies for OpResults that were decided
21 //    to bufferize inplace during the analysis phase.
22 //
23 // This file contains only the analysis. The actual bufferization is implemented
24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
25 // helper function `runOneShotBufferize` that analyzes an op (and its nested
26 // ops) and then bufferizes it.
27 //
28 // Inplace bufferization decisions are passed from the analysis to the
29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
30 // They can be printed for debugging purposes with `testAnalysisOnly`.
31 //
32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
33 // treated conservatively. E.g., the analysis has to assume that their tensor
34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
36 // inserted at the bufferization boundary.
37 //
38 // This analysis caters to high-performance codegen where buffer reuse is deemed
39 // critical: the analysis should fail if the bufferized form of the function
40 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
41 
42 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
43 
44 #include <random>
45 
46 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
47 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49 #include "mlir/Dialect/Func/IR/FuncOps.h"
50 #include "mlir/Dialect/MemRef/IR/MemRef.h"
51 #include "mlir/IR/AsmState.h"
52 #include "mlir/IR/Dominance.h"
53 #include "mlir/IR/Operation.h"
54 #include "mlir/IR/TypeUtilities.h"
55 #include "mlir/Interfaces/ControlFlowInterfaces.h"
56 #include "llvm/ADT/DenseSet.h"
57 #include "llvm/ADT/SetVector.h"
58 
59 using namespace mlir;
60 using namespace mlir::bufferization;
61 
62 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
63 
64 //===----------------------------------------------------------------------===//
65 // Bufferization-specific attribute manipulation.
66 // These are for testing and debugging only. Bufferization information is
67 // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
68 // is annotated with the results of the analysis (copied from
69 // BufferizationAliasInfo), so that they can be checked in tests.
70 //===----------------------------------------------------------------------===//
71 
/// Name of the array attribute that records, per op operand, whether the
/// operand bufferizes in-place. Despite the identifier, the entries written by
/// `setInPlaceOpOperand` are indexed by operand number, not by result number.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
74 
75 /// Mark whether OpOperand will be bufferized inplace.
76 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
77   Operation *op = opOperand.getOwner();
78   auto attr =
79       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
80   SmallVector<StringRef> inPlaceVector;
81   if (attr) {
82     inPlaceVector = SmallVector<StringRef>(
83         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
84   } else {
85     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
86     for (OpOperand &opOperand : op->getOpOperands())
87       if (opOperand.get().getType().isa<TensorType>())
88         inPlaceVector[opOperand.getOperandNumber()] = "false";
89   }
90 
91   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
92   op->setAttr(kInPlaceResultsAttrName,
93               OpBuilder(op).getStrArrayAttr(inPlaceVector));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // BufferizationAliasInfo
98 //===----------------------------------------------------------------------===//
99 
100 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
101   rootOp->walk([&](Operation *op) {
102     for (Value v : op->getResults())
103       if (v.getType().isa<TensorType>())
104         createAliasInfoEntry(v);
105     for (Region &r : op->getRegions())
106       for (Block &b : r.getBlocks())
107         for (auto bbArg : b.getArguments())
108           if (bbArg.getType().isa<TensorType>())
109             createAliasInfoEntry(bbArg);
110   });
111 }
112 
113 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
114 /// beginning the alias and equivalence sets only contain `v` itself.
115 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
116   aliasInfo.insert(v);
117   equivalentInfo.insert(v);
118 }
119 
120 /// Insert an info entry for `newValue` and merge its alias set with that of
121 /// `alias`.
122 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
123   createAliasInfoEntry(newValue);
124   aliasInfo.unionSets(newValue, alias);
125 }
126 
127 /// Insert an info entry for `newValue` and merge its alias set with that of
128 /// `alias`. Additionally, merge their equivalence classes.
129 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
130                                                         Value alias) {
131   insertNewBufferAlias(newValue, alias);
132   equivalentInfo.unionSets(newValue, alias);
133 }
134 
135 /// Return `true` if a value was marked as in-place bufferized.
136 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
137   return inplaceBufferized.contains(&operand);
138 }
139 
140 /// Set the inPlace bufferization spec to true.
141 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
142                                               AnalysisState &state) {
143   markInPlace(operand);
144   for (OpResult result : state.getAliasingOpResult(operand))
145     aliasInfo.unionSets(result, operand.get());
146 }
147 
148 /// Set the inPlace bufferization spec to false.
149 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
150   assert(!inplaceBufferized.contains(&operand) &&
151          "OpOperand was already decided to bufferize inplace");
152 }
153 
154 /// Apply `fun` to all the members of the equivalence class of `v`.
155 void BufferizationAliasInfo::applyOnEquivalenceClass(
156     Value v, function_ref<void(Value)> fun) const {
157   auto leaderIt = equivalentInfo.findLeader(v);
158   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
159        ++mit) {
160     fun(*mit);
161   }
162 }
163 
164 /// Apply `fun` to all aliases of `v`.
165 void BufferizationAliasInfo::applyOnAliases(
166     Value v, function_ref<void(Value)> fun) const {
167   auto leaderIt = aliasInfo.findLeader(v);
168   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
169     fun(*mit);
170   }
171 }
172 
173 BufferizationAliasInfo::EquivalenceClassRangeType
174 BufferizationAliasInfo::getAliases(Value v) const {
175   DenseSet<Value> res;
176   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
177   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
178        mit != meit; ++mit) {
179     res.insert(static_cast<Value>(*mit));
180   }
181   return BufferizationAliasInfo::EquivalenceClassRangeType(
182       aliasInfo.member_begin(it), aliasInfo.member_end());
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // OneShotAnalysisState
187 //===----------------------------------------------------------------------===//
188 
189 OneShotAnalysisState::OneShotAnalysisState(
190     Operation *op, const OneShotBufferizationOptions &options)
191     : AnalysisState(options), aliasInfo(op) {
192   // Set up alias sets for OpResults that must bufferize in-place. This should
193   // be done before making any other bufferization decisions.
194   op->walk([&](BufferizableOpInterface bufferizableOp) {
195     if (!options.isOpAllowed(bufferizableOp))
196       return WalkResult::skip();
197     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
198       if (opOperand.get().getType().isa<TensorType>())
199         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
200           for (OpResult opResult :
201                bufferizableOp.getAliasingOpResult(opOperand, *this))
202             aliasInfo.unionAliasSets(opOperand.get(), opResult);
203           aliasInfo.markInPlace(opOperand);
204         }
205     }
206     return WalkResult::advance();
207   });
208 }
209 
210 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
211   return aliasInfo.isInPlace(opOperand);
212 }
213 
214 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
215                                                          Value v2) const {
216   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
217 }
218 
219 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
220 // to ensure that such information is available during bufferization time.
221 // Alias information can no longer be queried through BufferizationAliasInfo
222 // once we have started modifying the IR.
223 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
224   op->walk([&](Operation *returnOp) {
225     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
226       return WalkResult::advance();
227 
228     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
229       Value returnVal = returnValOperand.get();
230       // Skip non-tensor values.
231       if (!returnVal.getType().isa<TensorType>())
232         continue;
233 
234       // Add all aliases of the returned value. But only the ones that are in
235       // the same block.
236       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
237         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
238           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
239             yieldedTensors.insert(bbArg);
240           return;
241         }
242         Operation *definingOp = v.getDefiningOp();
243         if (definingOp->getParentOp() == returnOp->getParentOp())
244           yieldedTensors.insert(v);
245       });
246     }
247 
248     return WalkResult::advance();
249   });
250 }
251 
252 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
253   return yieldedTensors.contains(tensor);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // Bufferization-specific alias analysis.
258 //===----------------------------------------------------------------------===//
259 
260 /// Return true if opOperand has been decided to bufferize in-place.
261 static bool isInplaceMemoryWrite(OpOperand &opOperand,
262                                  const BufferizationAliasInfo &aliasInfo,
263                                  AnalysisState &state) {
264   // OpOperands that do not bufferize to a memory write do not write in-place.
265   if (!state.bufferizesToMemoryWrite(opOperand))
266     return false;
267   // Check current bufferization decisions.
268   return aliasInfo.isInPlace(opOperand);
269 }
270 
271 /// Return true if, under current bufferization decisions, the buffer of `value`
272 /// is not writable.
273 static bool aliasesNonWritableBuffer(Value value,
274                                      const BufferizationAliasInfo &aliasInfo,
275                                      AnalysisState &state) {
276   bool foundNonWritableBuffer = false;
277   aliasInfo.applyOnAliases(value, [&](Value v) {
278     // Query BufferizableOpInterface to see if the value is writable.
279     // TODO: Out-of-place bufferized value could be considered writable.
280     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
281       if (bufferizableOp && bufferizableOp.isWritable(v, state))
282         return;
283 
284     // Query BufferizableOpInterface to see if the BlockArgument is writable.
285     if (auto bbArg = v.dyn_cast<BlockArgument>())
286       if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
287               bbArg.getOwner()->getParentOp()))
288         if (bufferizableOp.isWritable(bbArg, state))
289           return;
290 
291     foundNonWritableBuffer = true;
292   });
293 
294   return foundNonWritableBuffer;
295 }
296 
297 /// Return true if the buffer to which `operand` would bufferize is equivalent
298 /// to some buffer write.
299 static bool aliasesInPlaceWrite(Value value,
300                                 const BufferizationAliasInfo &aliasInfo,
301                                 AnalysisState &state) {
302   bool foundInplaceWrite = false;
303   aliasInfo.applyOnAliases(value, [&](Value v) {
304     for (auto &use : v.getUses()) {
305       if (isInplaceMemoryWrite(use, aliasInfo, state)) {
306         foundInplaceWrite = true;
307         return;
308       }
309     }
310   });
311   return foundInplaceWrite;
312 }
313 
314 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
315 /// properly dominates `b` and `b` is not inside `a`.
316 static bool happensBefore(Operation *a, Operation *b,
317                           const DominanceInfo &domInfo) {
318   do {
319     // TODO: Instead of isProperAncestor + properlyDominates, we should use
320     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
321     if (a->isProperAncestor(b))
322       return false;
323     if (domInfo.properlyDominates(a, b))
324       return true;
325   } while ((a = a->getParentOp()));
326   return false;
327 }
328 
329 /// For each given value, find the closest enclosing repetitive region. If this
330 /// is the same region for each value, return it. Otherwise return None.
331 /// Note: If there is no enclosing repetitive region, return nullptr.
332 static Optional<Region *>
333 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
334   if (values.empty())
335     return None;
336   Region *r = getEnclosingRepetitiveRegion(values.front());
337   for (Value value : values.drop_front())
338     if (getEnclosingRepetitiveRegion(value) != r)
339       return None;
340   return r;
341 }
342 
343 /// Annotate IR with details about the detected RaW conflict.
344 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
345                              Value lastWrite) {
346   static uint64_t counter = 0;
347   Operation *readingOp = uRead->getOwner();
348   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
349 
350   OpBuilder b(conflictingWritingOp->getContext());
351   std::string id = "C_" + std::to_string(counter++);
352 
353   std::string conflictingWriteAttr =
354       id +
355       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
356       "]";
357   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
358 
359   std::string readAttr =
360       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
361   readingOp->setAttr(readAttr, b.getUnitAttr());
362 
363   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
364     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
365                                 std::to_string(opResult.getResultNumber()) +
366                                 "]";
367     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
368   } else {
369     auto bbArg = lastWrite.cast<BlockArgument>();
370     std::string lastWriteAttr =
371         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
372     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
373   }
374 }
375 
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
///
/// `usesRead` are the reading uses and `usesWrite` the (assumed in-place)
/// writing uses to check against each other. `domInfo` is used for the
/// `happensBefore` queries and `state` for the op interface queries and the
/// options. `aliasInfo` is not referenced in the body of this function.
///
/// The function returns at the first conflict found. If
/// `options.printConflicts` is set, that conflict is annotated in the IR via
/// `annotateConflict` before returning.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  // Otherwise (or if there are no writes), `repetitiveRegionOfWrites` is None.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `repetitiveRegionOfWrites` is None, the comparison below is
      // false and op dominance is not used.
      bool canUseOpDominance =
          repetitiveRegionOfWrites ==
          getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      //       a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface of the reading op says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      // No conflict if the op interface of the writing op says so. The query
      // is skipped if both uses belong to the same op, which was already asked
      // above.
      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          // The last write has no defining op, i.e., it is a block argument.
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}
535 
536 /// Return true if bufferizing `operand` inplace would create a conflict. A read
537 /// R and a write W of the same alias set is a conflict if inplace bufferization
538 /// of W changes the value read by R to a value different from the one that
539 /// would be expected by tracing back R's origin through SSA use-def chains.
540 /// A conflict can only be introduced by a new alias and/or an inplace
541 /// bufferization decision.
542 ///
543 /// Example:
544 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
545 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
546 /// %e = tensor.extract_slice %1
547 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
548 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
549 ///
550 /// In the above example, the two TransferWriteOps have already been decided to
551 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
552 /// conflict because:
553 /// * According to SSA use-def chains, we expect to read the result of %1.
554 /// * However, adding an alias {%0, %t} would mean that the second
555 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
556 ///   would no longer be reading the result of %1.
557 ///
558 /// If `checkConsistencyOnly` is true, this function checks if there is a
559 /// read-after-write conflict without bufferizing `operand` inplace. This would
560 /// indicate a problem with the current inplace bufferization decisions.
561 ///
/// Note: If `checkConsistencyOnly`, `operand` itself is not added to the set
/// of writes. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand (and of its aliasing OpResults)
/// is checked.
565 static bool wouldCreateReadAfterWriteInterference(
566     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
567     const BufferizationAliasInfo &aliasInfo,
568     bool checkConsistencyOnly = false) {
569   // Helper function to iterate on aliases of `root` and capture the reads.
570   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
571     aliasInfo.applyOnAliases(root, [&](Value alias) {
572       for (auto &use : alias.getUses())
573         // Read to a value that aliases root.
574         if (state.bufferizesToMemoryRead(use))
575           res.insert(&use);
576     });
577   };
578 
579   // Helper function to iterate on aliases of `root` and capture the writes.
580   auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
581     aliasInfo.applyOnAliases(root, [&](Value alias) {
582       for (auto &use : alias.getUses())
583         // Inplace write to a value that aliases root.
584         if (isInplaceMemoryWrite(use, aliasInfo, state))
585           res.insert(&use);
586     });
587   };
588 
589   // Collect reads and writes of all aliases of OpOperand and OpResult.
590   DenseSet<OpOperand *> usesRead, usesWrite;
591   getAliasingReads(usesRead, operand.get());
592   getAliasingInplaceWrites(usesWrite, operand.get());
593   for (OpResult result : state.getAliasingOpResult(operand)) {
594     getAliasingReads(usesRead, result);
595     getAliasingInplaceWrites(usesWrite, result);
596   }
597   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
598     usesWrite.insert(&operand);
599 
600   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
601                                        aliasInfo);
602 }
603 
604 /// Return true if bufferizing `opOperand` inplace would create a write to a
605 /// non-writable buffer.
606 static bool
607 wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
608                                     const BufferizationAliasInfo &aliasInfo,
609                                     AnalysisState &state) {
610   // Certain buffers are not writeable:
611   //   1. A function bbArg that is not inplaceable or
612   //   2. A constant op.
613   bool nonWritable =
614       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
615   if (!nonWritable)
616     return false;
617 
618   // This is a problem only if the buffer is written to via some alias.
619   bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
620                   state.bufferizesToMemoryWrite(opOperand);
621 
622   for (OpResult opResult : state.getAliasingOpResult(opOperand))
623     hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
624 
625   return hasWrite;
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // Bufferization analyses.
630 //===----------------------------------------------------------------------===//
631 
632 /// Determine if `operand` can be bufferized in-place.
633 static LogicalResult bufferizableInPlaceAnalysisImpl(
634     OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
635     const DominanceInfo &domInfo) {
636   bool foundInterference =
637       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
638       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
639 
640   if (foundInterference)
641     aliasInfo.bufferizeOutOfPlace(operand);
642   else
643     aliasInfo.bufferizeInPlace(operand, state);
644 
645   return success();
646 }
647 
648 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
649 /// reverse and bufferize ops greedily. This is a good starter heuristic.
650 ///
651 /// Even if an op does not read or write, it may still create an alias when
652 /// bufferized in-place. An example of such ops is tensor.extract_slice.
653 ///
654 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
655 ///
656 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
657 /// cannot change the flow of information for either the source or the
658 /// result buffers.
659 ///
660 /// When bufferized inplace, an ExtractSliceOp does not by itself create any
661 /// read or write from memory. Instead, it has the effect of merging the alias
662 /// sets of the source and the result buffers.
663 ///
664 /// An analysis is required to ensure inplace bufferization would not result in
665 /// RaW dependence violations.
666 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
667                                      BufferizationAliasInfo &aliasInfo,
668                                      AnalysisState &state,
669                                      const DominanceInfo &domInfo,
670                                      unsigned analysisFuzzerSeed = 0) {
671   if (analysisFuzzerSeed) {
672     // This is a fuzzer. For testing purposes only. Randomize the order in which
673     // operations are analyzed. The bufferization quality is likely worse, but
674     // we want to make sure that no assertions are triggered anywhere.
675     std::mt19937 g(analysisFuzzerSeed);
676     llvm::shuffle(ops.begin(), ops.end(), g);
677   }
678 
679   // Walk ops in reverse for better interference analysis.
680   for (Operation *op : reverse(ops))
681     for (OpOperand &opOperand : op->getOpOperands())
682       if (opOperand.get().getType().isa<TensorType>())
683         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
684           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
685                                                      state, domInfo)))
686             return failure();
687 
688   return success();
689 }
690 
691 /// Return true if the given op has a tensor result or a tensor operand.
692 static bool hasTensorSemantics(Operation *op) {
693   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
694   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
695   return hasTensorResult || hasTensorOperand;
696 }
697 
698 /// Analyze all ops that are contained in `op`.
699 static LogicalResult inPlaceAnalysis(Operation *op,
700                                      BufferizationAliasInfo &aliasInfo,
701                                      AnalysisState &state,
702                                      const DominanceInfo &domInfo,
703                                      unsigned analysisFuzzerSeed = 0) {
704   // Collect ops so we can build our own reverse traversal.
705   SmallVector<Operation *> ops;
706   op->walk([&](Operation *op) {
707     // No tensors => no buffers.
708     if (!hasTensorSemantics(op))
709       return;
710     ops.push_back(op);
711   });
712 
713   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
714 }
715 
716 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
717 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
718                                 BufferizationAliasInfo &aliasInfo,
719                                 AnalysisState &state) {
720   for (Operation *op : ops)
721     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
722       for (OpResult opResult : op->getOpResults())
723         if (opResult.getType().isa<TensorType>())
724           for (OpOperand *opOperand :
725                bufferizableOp.getAliasingOpOperand(opResult, state))
726             if (state.isInPlace(*opOperand))
727               if (bufferizableOp.bufferRelation(opResult, state) ==
728                   BufferRelation::Equivalent)
729                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
730 }
731 
732 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
733 /// in `op`.
734 static void equivalenceAnalysis(Operation *op,
735                                 BufferizationAliasInfo &aliasInfo,
736                                 AnalysisState &state) {
737   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
738   SmallVector<Operation *> ops;
739   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
740     // No tensors => no buffers.
741     if (none_of(op->getResultTypes(), isaTensor))
742       return;
743     ops.push_back(op);
744   });
745 
746   equivalenceAnalysis(ops, aliasInfo, state);
747 }
748 
749 /// Assert that the current bufferization decisions are consistent.
750 static LogicalResult
751 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
752                           AnalysisState &state,
753                           const BufferizationAliasInfo &aliasInfo) {
754   const BufferizationOptions &options = state.getOptions();
755   Operation *inconsistentOp = nullptr;
756   WalkResult walkResult = op->walk([&](Operation *op) {
757     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
758       for (OpOperand &opOperand : op->getOpOperands())
759         if (opOperand.get().getType().isa<TensorType>()) {
760           if (wouldCreateReadAfterWriteInterference(
761                   opOperand, domInfo, state, aliasInfo,
762                   /*checkConsistencyOnly=*/true)) {
763             // This error can happen if certain "mustBufferizeInPlace" interface
764             // methods are implemented incorrectly, such that the IR already has
765             // a RaW conflict before making any bufferization decisions.
766             inconsistentOp = op;
767             return WalkResult::interrupt();
768           }
769         }
770     return WalkResult::advance();
771   });
772 
773   if (walkResult.wasInterrupted())
774     return inconsistentOp->emitError("input IR has RaW conflict");
775   return success();
776 }
777 
778 /// Annotate the IR with the result of the analysis. For testing/debugging only.
779 static void
780 annotateOpsWithBufferizationMarkers(Operation *op,
781                                     const BufferizationAliasInfo &aliasInfo,
782                                     AnalysisState &state) {
783   op->walk([&](Operation *op) {
784     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
785       for (OpOperand &opOperand : op->getOpOperands())
786         if (opOperand.get().getType().isa<TensorType>())
787           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
788   });
789 }
790 
791 /// Assert that IR is in destination-passing style. I.e., every value that is
792 /// returned or yielded from a block is:
793 /// * aliasing a bbArg of that block or a parent block, or
794 /// * aliasing an OpResult of a op in a parent block.
795 ///
796 /// Example:
797 /// ```
798 /// %0 = "some_op" : tensor<?xf32>
799 /// %1 = scf.if %c -> (tensor<?xf32>) {
800 ///   scf.yield %0 : tensor<?xf32>
801 /// } else {
802 ///   %t = linalg.init_tensor : tensor<?xf32>
803 ///   scf.yield %t : tensor<?xf32>
804 /// }
805 /// ```
806 /// In the above example, the first scf.yield op satifies destination-passing
807 /// style because the yielded value %0 is defined in the parent block. The
808 /// second scf.yield op does not satisfy destination-passing style because the
809 /// yielded value %t is defined in the same block as the scf.yield op.
810 // TODO: The current implementation checks for equivalent values instead of
811 // aliasing values, which is stricter than needed. We can currently not check
812 // for aliasing values because the analysis is a maybe-alias analysis and we
813 // need a must-alias analysis here.
814 static LogicalResult
815 assertDestinationPassingStyle(Operation *op, AnalysisState &state,
816                               BufferizationAliasInfo &aliasInfo,
817                               SmallVector<Operation *> &newOps) {
818   LogicalResult status = success();
819   DominanceInfo domInfo(op);
820   op->walk([&](Operation *returnOp) {
821     if (!isRegionReturnLike(returnOp) ||
822         !state.getOptions().isOpAllowed(returnOp))
823       return WalkResult::advance();
824 
825     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
826       Value returnVal = returnValOperand.get();
827       // Skip non-tensor values.
828       if (!returnVal.getType().isa<TensorType>())
829         continue;
830 
831       bool foundEquivValue = false;
832       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
833         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
834           Operation *definingOp = bbArg.getOwner()->getParentOp();
835           if (definingOp->isProperAncestor(returnOp))
836             foundEquivValue = true;
837           return;
838         }
839 
840         Operation *definingOp = equivVal.getDefiningOp();
841         if (definingOp->getBlock()->findAncestorOpInBlock(
842                 *returnOp->getParentOp()))
843           // Skip ops that happen after `returnOp` and parent ops.
844           if (happensBefore(definingOp, returnOp, domInfo))
845             foundEquivValue = true;
846       });
847 
848       if (!foundEquivValue)
849         status =
850             returnOp->emitError()
851             << "operand #" << returnValOperand.getOperandNumber()
852             << " of ReturnLike op does not satisfy destination passing style";
853     }
854 
855     return WalkResult::advance();
856   });
857 
858   return status;
859 }
860 
861 LogicalResult bufferization::analyzeOp(Operation *op,
862                                        OneShotAnalysisState &state) {
863   DominanceInfo domInfo(op);
864   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
865   const auto &options =
866       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
867 
868   // Catch incorrect API usage.
869   assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
870           !options.bufferizeFunctionBoundaries) &&
871          "must use ModuleBufferize to bufferize function boundaries");
872 
873   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
874     return failure();
875 
876   // If the analysis fails, just return.
877   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
878                              options.analysisFuzzerSeed)))
879     return failure();
880   equivalenceAnalysis(op, aliasInfo, state);
881 
882   for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
883     SmallVector<Operation *> newOps;
884     if (failed(fn(op, state, aliasInfo, newOps)))
885       return failure();
886     // Analyze ops that were created by the PostAnalysisStepFn.
887     if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
888       return failure();
889     equivalenceAnalysis(newOps, aliasInfo, state);
890   }
891 
892   bool failedAnalysis = false;
893   if (!options.allowReturnAllocs) {
894     SmallVector<Operation *> newOps;
895     failedAnalysis |=
896         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
897   }
898 
899   // Gather all yielded tensors.
900   state.gatherYieldedTensors(op);
901 
902   // Analysis verification: After setting up alias/equivalence sets, each op
903   // can check for expected invariants/limitations and fail the analysis if
904   // necessary.
905   op->walk([&](Operation *op) {
906     if (BufferizableOpInterface bufferizableOp =
907             options.dynCastBufferizableOp(op))
908       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
909   });
910 
911   // Annotate operations if we only want to report the analysis.
912   if (options.testAnalysisOnly)
913     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
914 
915   return success(!failedAnalysis);
916 }
917 
918 LogicalResult
919 bufferization::runOneShotBufferize(Operation *op,
920                                    const OneShotBufferizationOptions &options) {
921   OneShotAnalysisState state(op, options);
922   if (failed(analyzeOp(op, state)))
923     return failure();
924   if (options.testAnalysisOnly)
925     return success();
926   return bufferizeOp(op, state);
927 }
928