1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
12 // call graphs.
13 //
14 // One-Shot Bufferize consists of two phases.
15 //
16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
17 //    inserting buffer copies. The analysis queries op bufferization semantics
18 //    via `BufferizableOpInterface`.
19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
20 //    function does not generate buffer copies for OpResults that were decided
21 //    to bufferize inplace during the analysis phase.
22 //
23 // This file contains only the analysis. The actual bufferization is implemented
24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
25 // helper function `runOneShotBufferize` that analyzes an op (and its nested
26 // ops) and then bufferizes it.
27 //
28 // Inplace bufferization decisions are passed from the analysis to the
29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
30 // They can be printed for debugging purposes with `testAnalysisOnly`.
31 //
32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
33 // treated conservatively. E.g., the analysis has to assume that their tensor
34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
36 // inserted at the bufferization boundary.
37 //
38 // This analysis caters to high-performance codegen where buffer reuse is deemed
39 // critical: the analysis should fail if the bufferized form of the function
40 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
41 
42 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
43 
44 #include <random>
45 
46 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
47 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49 #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
50 #include "mlir/Dialect/Func/IR/FuncOps.h"
51 #include "mlir/Dialect/MemRef/IR/MemRef.h"
52 #include "mlir/IR/AsmState.h"
53 #include "mlir/IR/Dominance.h"
54 #include "mlir/IR/Operation.h"
55 #include "mlir/IR/TypeUtilities.h"
56 #include "mlir/Interfaces/ControlFlowInterfaces.h"
57 #include "llvm/ADT/DenseSet.h"
58 #include "llvm/ADT/SetVector.h"
59 
60 using namespace mlir;
61 using namespace mlir::bufferization;
62 
isaTensor(Type t)63 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
64 
65 //===----------------------------------------------------------------------===//
66 // Bufferization-specific attribute manipulation.
67 // These are for testing and debugging only. Bufferization information is
68 // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
69 // is annotated with the results of the analysis (copied from
70 // BufferizationAliasInfo), so that they can be checked in tests.
71 //===----------------------------------------------------------------------===//
72 
/// Name of the string-array attribute that records in-place decisions. There is
/// one entry per OpOperand (despite the "Results" in the variable name):
/// "true" (in-place), "false" (out-of-place), or "none" (non-tensor operand).
/// It is written by `setInPlaceOpOperand` and used for testing/debugging only.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
75 
76 /// Mark whether OpOperand will be bufferized inplace.
setInPlaceOpOperand(OpOperand & opOperand,bool inPlace)77 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
78   Operation *op = opOperand.getOwner();
79   auto attr =
80       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
81   SmallVector<StringRef> inPlaceVector;
82   if (attr) {
83     inPlaceVector = SmallVector<StringRef>(
84         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
85   } else {
86     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
87     for (OpOperand &opOperand : op->getOpOperands())
88       if (opOperand.get().getType().isa<TensorType>())
89         inPlaceVector[opOperand.getOperandNumber()] = "false";
90   }
91 
92   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
93   op->setAttr(kInPlaceResultsAttrName,
94               OpBuilder(op).getStrArrayAttr(inPlaceVector));
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // BufferizationAliasInfo
99 //===----------------------------------------------------------------------===//
100 
BufferizationAliasInfo(Operation * rootOp)101 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
102   rootOp->walk([&](Operation *op) {
103     for (Value v : op->getResults())
104       if (v.getType().isa<TensorType>())
105         createAliasInfoEntry(v);
106     for (Region &r : op->getRegions())
107       for (Block &b : r.getBlocks())
108         for (auto bbArg : b.getArguments())
109           if (bbArg.getType().isa<TensorType>())
110             createAliasInfoEntry(bbArg);
111   });
112 }
113 
114 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
115 /// beginning the alias and equivalence sets only contain `v` itself.
createAliasInfoEntry(Value v)116 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
117   aliasInfo.insert(v);
118   equivalentInfo.insert(v);
119 }
120 
121 /// Insert an info entry for `newValue` and merge its alias set with that of
122 /// `alias`.
insertNewBufferAlias(Value newValue,Value alias)123 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
124   createAliasInfoEntry(newValue);
125   aliasInfo.unionSets(newValue, alias);
126 }
127 
128 /// Insert an info entry for `newValue` and merge its alias set with that of
129 /// `alias`. Additionally, merge their equivalence classes.
insertNewBufferEquivalence(Value newValue,Value alias)130 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
131                                                         Value alias) {
132   insertNewBufferAlias(newValue, alias);
133   equivalentInfo.unionSets(newValue, alias);
134 }
135 
136 /// Return `true` if a value was marked as in-place bufferized.
isInPlace(OpOperand & operand) const137 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
138   return inplaceBufferized.contains(&operand);
139 }
140 
141 /// Set the inPlace bufferization spec to true.
bufferizeInPlace(OpOperand & operand,AnalysisState & state)142 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
143                                               AnalysisState &state) {
144   markInPlace(operand);
145   for (OpResult result : state.getAliasingOpResult(operand))
146     aliasInfo.unionSets(result, operand.get());
147 }
148 
149 /// Set the inPlace bufferization spec to false.
bufferizeOutOfPlace(OpOperand & operand)150 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
151   assert(!inplaceBufferized.contains(&operand) &&
152          "OpOperand was already decided to bufferize inplace");
153 }
154 
155 /// Apply `fun` to all the members of the equivalence class of `v`.
applyOnEquivalenceClass(Value v,function_ref<void (Value)> fun) const156 void BufferizationAliasInfo::applyOnEquivalenceClass(
157     Value v, function_ref<void(Value)> fun) const {
158   auto leaderIt = equivalentInfo.findLeader(v);
159   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
160        ++mit) {
161     fun(*mit);
162   }
163 }
164 
165 /// Apply `fun` to all aliases of `v`.
applyOnAliases(Value v,function_ref<void (Value)> fun) const166 void BufferizationAliasInfo::applyOnAliases(
167     Value v, function_ref<void(Value)> fun) const {
168   auto leaderIt = aliasInfo.findLeader(v);
169   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
170     fun(*mit);
171   }
172 }
173 
174 BufferizationAliasInfo::EquivalenceClassRangeType
getAliases(Value v) const175 BufferizationAliasInfo::getAliases(Value v) const {
176   DenseSet<Value> res;
177   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
178   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
179        mit != meit; ++mit) {
180     res.insert(static_cast<Value>(*mit));
181   }
182   return BufferizationAliasInfo::EquivalenceClassRangeType(
183       aliasInfo.member_begin(it), aliasInfo.member_end());
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // OneShotAnalysisState
188 //===----------------------------------------------------------------------===//
189 
OneShotAnalysisState(Operation * op,const OneShotBufferizationOptions & options)190 OneShotAnalysisState::OneShotAnalysisState(
191     Operation *op, const OneShotBufferizationOptions &options)
192     : AnalysisState(options), aliasInfo(op) {
193   // Set up alias sets for OpResults that must bufferize in-place. This should
194   // be done before making any other bufferization decisions.
195   op->walk([&](BufferizableOpInterface bufferizableOp) {
196     if (!options.isOpAllowed(bufferizableOp))
197       return WalkResult::skip();
198     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
199       if (opOperand.get().getType().isa<TensorType>())
200         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
201           for (OpResult opResult :
202                bufferizableOp.getAliasingOpResult(opOperand, *this))
203             aliasInfo.unionAliasSets(opOperand.get(), opResult);
204           aliasInfo.markInPlace(opOperand);
205         }
206     }
207     return WalkResult::advance();
208   });
209 }
210 
isInPlace(OpOperand & opOperand) const211 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
212   return aliasInfo.isInPlace(opOperand);
213 }
214 
areEquivalentBufferizedValues(Value v1,Value v2) const215 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
216                                                          Value v2) const {
217   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
218 }
219 
areAliasingBufferizedValues(Value v1,Value v2) const220 bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
221                                                        Value v2) const {
222   return aliasInfo.areAliasingBufferizedValues(v1, v2);
223 }
224 
225 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
226 // to ensure that such information is available during bufferization time.
227 // Alias information can no longer be queried through BufferizationAliasInfo
228 // once we have started modifying the IR.
gatherYieldedTensors(Operation * op)229 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
230   op->walk([&](Operation *returnOp) {
231     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
232       return WalkResult::advance();
233 
234     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
235       Value returnVal = returnValOperand.get();
236       // Skip non-tensor values.
237       if (!returnVal.getType().isa<TensorType>())
238         continue;
239 
240       // Add all aliases of the returned value. But only the ones that are in
241       // the same block.
242       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
243         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
244           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
245             yieldedTensors.insert(bbArg);
246           return;
247         }
248         Operation *definingOp = v.getDefiningOp();
249         if (definingOp->getParentOp() == returnOp->getParentOp())
250           yieldedTensors.insert(v);
251       });
252     }
253 
254     return WalkResult::advance();
255   });
256 }
257 
gatherUndefinedTensorUses(Operation * op)258 void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
259   op->walk([&](Operation *op) {
260     // Skip unknown ops.
261     auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
262     if (!bufferizableOp)
263       return WalkResult::skip();
264 
265     // Check all tensor OpResults.
266     for (OpResult opResult : op->getOpResults()) {
267       if (!opResult.getType().isa<TensorType>())
268         continue;
269 
270       // If there is no preceding memory write, the tensor contents are
271       // undefined.
272       // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
273       // use-def chain, it returns that value, regardless of whether it is a
274       // memory write or not.
275       SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
276       bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
277         if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
278           return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
279                                               *this);
280         return true;
281       });
282       if (isUndefined)
283         for (OpOperand &use : opResult.getUses())
284           undefinedTensorUses.insert(&use);
285     }
286 
287     return WalkResult::advance();
288   });
289 }
290 
hasUndefinedContents(OpOperand * opOperand) const291 bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
292   return undefinedTensorUses.contains(opOperand);
293 }
294 
isTensorYielded(Value tensor) const295 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
296   return yieldedTensors.contains(tensor);
297 }
298 
isValueWritten(Value value) const299 bool OneShotAnalysisState::isValueWritten(Value value) const {
300   bool isWritten = false;
301   aliasInfo.applyOnAliases(value, [&](Value val) {
302     for (OpOperand &use : val.getUses())
303       if (isInPlace(use) && bufferizesToMemoryWrite(use))
304         isWritten = true;
305   });
306   return isWritten;
307 }
308 
isWritable(Value value) const309 bool OneShotAnalysisState::isWritable(Value value) const {
310   // TODO: Out-of-place bufferized value could be considered writable.
311   if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
312     return bufferizableOp.isWritable(value, *this);
313 
314   // Query BufferizableOpInterface to see if the BlockArgument is writable.
315   if (auto bbArg = value.dyn_cast<BlockArgument>())
316     if (auto bufferizableOp =
317             getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
318       return bufferizableOp.isWritable(bbArg, *this);
319 
320   // Not a bufferizable op: The conservative answer is "not writable".
321   return false;
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // Bufferization-specific alias analysis.
326 //===----------------------------------------------------------------------===//
327 
328 /// Return true if opOperand has been decided to bufferize in-place.
isInplaceMemoryWrite(OpOperand & opOperand,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)329 static bool isInplaceMemoryWrite(OpOperand &opOperand,
330                                  const BufferizationAliasInfo &aliasInfo,
331                                  const AnalysisState &state) {
332   // OpOperands that do not bufferize to a memory write do not write in-place.
333   if (!state.bufferizesToMemoryWrite(opOperand))
334     return false;
335   // Check current bufferization decisions.
336   return aliasInfo.isInPlace(opOperand);
337 }
338 
339 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
340 /// properly dominates `b` and `b` is not inside `a`.
happensBefore(Operation * a,Operation * b,const DominanceInfo & domInfo)341 static bool happensBefore(Operation *a, Operation *b,
342                           const DominanceInfo &domInfo) {
343   do {
344     // TODO: Instead of isProperAncestor + properlyDominates, we should use
345     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
346     if (a->isProperAncestor(b))
347       return false;
348     if (domInfo.properlyDominates(a, b))
349       return true;
350   } while ((a = a->getParentOp()));
351   return false;
352 }
353 
354 /// For each given value, find the closest enclosing repetitive region. If this
355 /// is the same region for each value, return it. Otherwise return None.
356 /// Note: If there is no enclosing repetitive region, return nullptr.
357 static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values)358 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
359   if (values.empty())
360     return None;
361   Region *r = getEnclosingRepetitiveRegion(values.front());
362   for (Value value : values.drop_front())
363     if (getEnclosingRepetitiveRegion(value) != r)
364       return None;
365   return r;
366 }
367 
368 /// Return `true` if the given tensor value is a memory write. Most values are
369 /// tensor writes, but ops that define a tensor SSA value without specifying its
370 /// contents (e.g., alloc_tensor) are not.
isMemoryWrite(Value value,const AnalysisState & state)371 static bool isMemoryWrite(Value value, const AnalysisState &state) {
372   auto opResult = value.dyn_cast<OpResult>();
373   if (!opResult)
374     return true;
375   auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
376   if (!bufferizableOp)
377     return true;
378   return bufferizableOp.isMemoryWrite(opResult, state);
379 }
380 
381 /// Annotate IR with details about the detected RaW conflict.
annotateConflict(OpOperand * uRead,OpOperand * uConflictingWrite,Value lastWrite)382 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
383                              Value lastWrite) {
384   static uint64_t counter = 0;
385   Operation *readingOp = uRead->getOwner();
386   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
387 
388   OpBuilder b(conflictingWritingOp->getContext());
389   std::string id = "C_" + std::to_string(counter++);
390 
391   std::string conflictingWriteAttr =
392       id +
393       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
394       "]";
395   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
396 
397   std::string readAttr =
398       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
399   readingOp->setAttr(readAttr, b.getUnitAttr());
400 
401   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
402     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
403                                 std::to_string(opResult.getResultNumber()) +
404                                 "]";
405     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
406   } else {
407     auto bbArg = lastWrite.cast<BlockArgument>();
408     std::string lastWriteAttr =
409         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
410     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
411   }
412 }
413 
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
///
/// `usesRead` are the reading OpOperands and `usesWrite` the in-place writing
/// OpOperands that are checked against each other. `domInfo` is used to decide
/// `happensBefore` relationships between ops. If `options.printConflicts` is
/// set, the conflict that is found is annotated in the IR (see
/// `annotateConflict`) before returning `true`. Note: `aliasInfo` is currently
/// not referenced in this function body.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases. Skip over aliases that are not actual writes.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    if (isMemoryWrite(uWrite->get(), state))
      writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `writtenAliases` is empty, there are no memory writes outside
      // of the repetitive region of conflictingWritingOp, which means that all
      // relevant aliases are inside the same repetitive region.
      bool canUseOpDominance =
          writtenAliases.empty() ||
          repetitiveRegionOfWrites ==
              getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      //       a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so. Both the reading op and (if
      // different) the writing op get a chance to rule out the conflict.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes. A single last write that survives all
      // of the checks below is enough to report a conflict.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}
579 
580 // Helper function to iterate on aliases of `root` and capture the writes.
getAliasingInplaceWrites(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)581 static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
582                                      const BufferizationAliasInfo &aliasInfo,
583                                      const AnalysisState &state) {
584   aliasInfo.applyOnAliases(root, [&](Value alias) {
585     for (auto &use : alias.getUses())
586       // Inplace write to a value that aliases root.
587       if (isInplaceMemoryWrite(use, aliasInfo, state))
588         res.insert(&use);
589   });
590 }
591 
592 // Helper function to iterate on aliases of `root` and capture the reads.
getAliasingReads(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)593 static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
594                              const BufferizationAliasInfo &aliasInfo,
595                              const AnalysisState &state) {
596   aliasInfo.applyOnAliases(root, [&](Value alias) {
597     for (auto &use : alias.getUses())
598       // Read to a value that aliases root.
599       if (state.bufferizesToMemoryRead(use))
600         res.insert(&use);
601   });
602 }
603 
604 /// Return true if bufferizing `operand` inplace would create a conflict. A read
605 /// R and a write W of the same alias set is a conflict if inplace bufferization
606 /// of W changes the value read by R to a value different from the one that
607 /// would be expected by tracing back R's origin through SSA use-def chains.
608 /// A conflict can only be introduced by a new alias and/or an inplace
609 /// bufferization decision.
610 ///
611 /// Example:
612 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
613 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
614 /// %e = tensor.extract_slice %1
615 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
616 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
617 ///
618 /// In the above example, the two TransferWriteOps have already been decided to
619 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
620 /// conflict because:
621 /// * According to SSA use-def chains, we expect to read the result of %1.
622 /// * However, adding an alias {%0, %t} would mean that the second
623 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
624 ///   would no longer be reading the result of %1.
625 ///
626 /// If `checkConsistencyOnly` is true, this function checks if there is a
627 /// read-after-write conflict without bufferizing `operand` inplace. This would
628 /// indicate a problem with the current inplace bufferization decisions.
629 ///
630 /// Note: If `checkConsistencyOnly`, this function may be called with a null
631 /// OpResult. In that case, only the consistency of bufferization decisions
632 /// involving aliases of the given OpOperand are checked.
wouldCreateReadAfterWriteInterference(OpOperand & operand,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo,bool checkConsistencyOnly=false)633 static bool wouldCreateReadAfterWriteInterference(
634     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
635     const BufferizationAliasInfo &aliasInfo,
636     bool checkConsistencyOnly = false) {
637   // Collect reads and writes of all aliases of OpOperand and OpResult.
638   DenseSet<OpOperand *> usesRead, usesWrite;
639   getAliasingReads(usesRead, operand.get(), aliasInfo, state);
640   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
641   for (OpResult result : state.getAliasingOpResult(operand)) {
642     getAliasingReads(usesRead, result, aliasInfo, state);
643     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
644   }
645   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
646     usesWrite.insert(&operand);
647 
648   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
649                                        aliasInfo);
650 }
651 
652 /// Check the reverse SSA use-def chain (following aliasing OpOperands) for
653 /// non-writable tensor values. Stop searching when an out-of-place bufferized
654 /// OpOperand was found (or when the OpOperand was not bufferized yet).
655 /// `currentOpOperand` is assumed to be in-place, even if that decision was not
656 /// materialized in `aliasInfo` yet.
657 static bool
hasPrecedingAliasingNonWritableTensor(Value value,OpOperand * currentOpOperand,const BufferizationAliasInfo & aliasInfo,const OneShotAnalysisState & state)658 hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
659                                       const BufferizationAliasInfo &aliasInfo,
660                                       const OneShotAnalysisState &state) {
661   SmallVector<Value> worklist;
662   worklist.push_back(value);
663   while (!worklist.empty()) {
664     Value nextVal = worklist.pop_back_val();
665     if (!state.isWritable(nextVal))
666       return true;
667 
668     // If `nextVal` is not a BlockArgument: End of use-def chain reached.
669     auto opResult = nextVal.dyn_cast<OpResult>();
670     if (!opResult)
671       continue;
672 
673     // Follow reverse SSA use-def chain.
674     SmallVector<OpOperand *> aliasingOpOperands =
675         state.getAliasingOpOperand(opResult);
676     for (OpOperand *opOperand : aliasingOpOperands)
677       if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
678         worklist.push_back(opOperand->get());
679   }
680   return false;
681 }
682 
683 /// Return true if bufferizing `operand` inplace would create a write to a
684 /// non-writable buffer.
wouldCreateWriteToNonWritableBuffer(OpOperand & operand,const BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,bool checkConsistencyOnly=false)685 static bool wouldCreateWriteToNonWritableBuffer(
686     OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
687     OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
688   // Collect writes of all aliases of OpOperand and OpResult.
689   DenseSet<OpOperand *> usesWrite;
690   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
691   for (OpResult result : state.getAliasingOpResult(operand)) {
692     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
693   }
694   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
695     usesWrite.insert(&operand);
696 
697   // Assuming that `operand` bufferizes in-place: For each write (to each
698   // alias), check if there is a non-writable tensor in the reverse SSA use-def
699   // chain.
700   for (OpOperand *uWrite : usesWrite)
701     if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
702                                               aliasInfo, state))
703       return true;
704 
705   return false;
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // Bufferization analyses.
710 //===----------------------------------------------------------------------===//
711 
712 /// Determine if `operand` can be bufferized in-place.
bufferizableInPlaceAnalysisImpl(OpOperand & operand,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo)713 static LogicalResult bufferizableInPlaceAnalysisImpl(
714     OpOperand &operand, BufferizationAliasInfo &aliasInfo,
715     OneShotAnalysisState &state, const DominanceInfo &domInfo) {
716   bool foundInterference =
717       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
718       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
719 
720   if (foundInterference)
721     aliasInfo.bufferizeOutOfPlace(operand);
722   else
723     aliasInfo.bufferizeInPlace(operand, state);
724 
725   return success();
726 }
727 
728 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
729 /// reverse and bufferize ops greedily. This is a good starter heuristic.
730 ///
731 /// Even if an op does not read or write, it may still create an alias when
732 /// bufferized in-place. An example of such ops is tensor.extract_slice.
733 ///
734 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
735 ///
736 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
737 /// cannot change the flow of information for either the source or the
738 /// result buffers.
739 ///
740 /// When bufferized inplace, an ExtractSliceOp does not by itself create any
741 /// read or write from memory. Instead, it has the effect of merging the alias
742 /// sets of the source and the result buffers.
743 ///
744 /// An analysis is required to ensure inplace bufferization would not result in
745 /// RaW dependence violations.
inPlaceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)746 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
747                                      BufferizationAliasInfo &aliasInfo,
748                                      OneShotAnalysisState &state,
749                                      const DominanceInfo &domInfo,
750                                      unsigned analysisFuzzerSeed = 0) {
751   if (analysisFuzzerSeed) {
752     // This is a fuzzer. For testing purposes only. Randomize the order in which
753     // operations are analyzed. The bufferization quality is likely worse, but
754     // we want to make sure that no assertions are triggered anywhere.
755     std::mt19937 g(analysisFuzzerSeed);
756     llvm::shuffle(ops.begin(), ops.end(), g);
757   }
758 
759   // Walk ops in reverse for better interference analysis.
760   for (Operation *op : reverse(ops))
761     for (OpOperand &opOperand : op->getOpOperands())
762       if (opOperand.get().getType().isa<TensorType>())
763         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
764           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
765                                                      state, domInfo)))
766             return failure();
767 
768   return success();
769 }
770 
771 /// Return true if the given op has a tensor result or a tensor operand.
hasTensorSemantics(Operation * op)772 static bool hasTensorSemantics(Operation *op) {
773   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
774   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
775   return hasTensorResult || hasTensorOperand;
776 }
777 
778 /// Analyze all ops that are contained in `op`.
inPlaceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)779 static LogicalResult inPlaceAnalysis(Operation *op,
780                                      BufferizationAliasInfo &aliasInfo,
781                                      OneShotAnalysisState &state,
782                                      const DominanceInfo &domInfo,
783                                      unsigned analysisFuzzerSeed = 0) {
784   // Collect ops so we can build our own reverse traversal.
785   SmallVector<Operation *> ops;
786   op->walk([&](Operation *op) {
787     // No tensors => no buffers.
788     if (!hasTensorSemantics(op))
789       return;
790     ops.push_back(op);
791   });
792 
793   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
794 }
795 
796 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
equivalenceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,AnalysisState & state)797 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
798                                 BufferizationAliasInfo &aliasInfo,
799                                 AnalysisState &state) {
800   for (Operation *op : ops)
801     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
802       for (OpResult opResult : op->getOpResults())
803         if (opResult.getType().isa<TensorType>())
804           for (OpOperand *opOperand :
805                bufferizableOp.getAliasingOpOperand(opResult, state))
806             if (state.isInPlace(*opOperand))
807               if (bufferizableOp.bufferRelation(opResult, state) ==
808                   BufferRelation::Equivalent)
809                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
810 }
811 
812 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
813 /// in `op`.
equivalenceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,AnalysisState & state)814 static void equivalenceAnalysis(Operation *op,
815                                 BufferizationAliasInfo &aliasInfo,
816                                 AnalysisState &state) {
817   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
818   SmallVector<Operation *> ops;
819   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
820     // No tensors => no buffers.
821     if (none_of(op->getResultTypes(), isaTensor))
822       return;
823     ops.push_back(op);
824   });
825 
826   equivalenceAnalysis(ops, aliasInfo, state);
827 }
828 
829 /// Assert that the current bufferization decisions are consistent.
830 static LogicalResult
checkAliasInfoConsistency(Operation * op,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo)831 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
832                           AnalysisState &state,
833                           const BufferizationAliasInfo &aliasInfo) {
834   const BufferizationOptions &options = state.getOptions();
835   Operation *inconsistentOp = nullptr;
836   WalkResult walkResult = op->walk([&](Operation *op) {
837     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
838       for (OpOperand &opOperand : op->getOpOperands())
839         if (opOperand.get().getType().isa<TensorType>()) {
840           if (wouldCreateReadAfterWriteInterference(
841                   opOperand, domInfo, state, aliasInfo,
842                   /*checkConsistencyOnly=*/true)) {
843             // This error can happen if certain "mustBufferizeInPlace" interface
844             // methods are implemented incorrectly, such that the IR already has
845             // a RaW conflict before making any bufferization decisions.
846             inconsistentOp = op;
847             return WalkResult::interrupt();
848           }
849         }
850     return WalkResult::advance();
851   });
852 
853   if (walkResult.wasInterrupted())
854     return inconsistentOp->emitError("input IR has RaW conflict");
855   return success();
856 }
857 
858 /// Annotate the IR with the result of the analysis. For testing/debugging only.
859 static void
annotateOpsWithBufferizationMarkers(Operation * op,const BufferizationAliasInfo & aliasInfo,AnalysisState & state)860 annotateOpsWithBufferizationMarkers(Operation *op,
861                                     const BufferizationAliasInfo &aliasInfo,
862                                     AnalysisState &state) {
863   op->walk([&](Operation *op) {
864     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
865       for (OpOperand &opOperand : op->getOpOperands())
866         if (opOperand.get().getType().isa<TensorType>())
867           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
868   });
869 }
870 
871 /// Assert that IR is in destination-passing style. I.e., every value that is
872 /// returned or yielded from a block is:
873 /// * aliasing a bbArg of that block or a parent block, or
874 /// * aliasing an OpResult of a op in a parent block.
875 ///
876 /// Example:
877 /// ```
878 /// %0 = "some_op" : tensor<?xf32>
879 /// %1 = scf.if %c -> (tensor<?xf32>) {
880 ///   scf.yield %0 : tensor<?xf32>
881 /// } else {
882 ///   %t = linalg.alloc_tensor : tensor<?xf32>
883 ///   scf.yield %t : tensor<?xf32>
884 /// }
885 /// ```
886 /// In the above example, the first scf.yield op satifies destination-passing
887 /// style because the yielded value %0 is defined in the parent block. The
888 /// second scf.yield op does not satisfy destination-passing style because the
889 /// yielded value %t is defined in the same block as the scf.yield op.
890 // TODO: The current implementation checks for equivalent values instead of
891 // aliasing values, which is stricter than needed. We can currently not check
892 // for aliasing values because the analysis is a maybe-alias analysis and we
893 // need a must-alias analysis here.
894 static LogicalResult
assertDestinationPassingStyle(Operation * op,AnalysisState & state,BufferizationAliasInfo & aliasInfo,SmallVector<Operation * > & newOps)895 assertDestinationPassingStyle(Operation *op, AnalysisState &state,
896                               BufferizationAliasInfo &aliasInfo,
897                               SmallVector<Operation *> &newOps) {
898   LogicalResult status = success();
899   DominanceInfo domInfo(op);
900   op->walk([&](Operation *returnOp) {
901     if (!isRegionReturnLike(returnOp) ||
902         !state.getOptions().isOpAllowed(returnOp))
903       return WalkResult::advance();
904 
905     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
906       Value returnVal = returnValOperand.get();
907       // Skip non-tensor values.
908       if (!returnVal.getType().isa<TensorType>())
909         continue;
910 
911       bool foundEquivValue = false;
912       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
913         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
914           Operation *definingOp = bbArg.getOwner()->getParentOp();
915           if (definingOp->isProperAncestor(returnOp))
916             foundEquivValue = true;
917           return;
918         }
919 
920         Operation *definingOp = equivVal.getDefiningOp();
921         if (definingOp->getBlock()->findAncestorOpInBlock(
922                 *returnOp->getParentOp()))
923           // Skip ops that happen after `returnOp` and parent ops.
924           if (happensBefore(definingOp, returnOp, domInfo))
925             foundEquivValue = true;
926       });
927 
928       if (!foundEquivValue)
929         status =
930             returnOp->emitError()
931             << "operand #" << returnValOperand.getOperandNumber()
932             << " of ReturnLike op does not satisfy destination passing style";
933     }
934 
935     return WalkResult::advance();
936   });
937 
938   return status;
939 }
940 
analyzeOp(Operation * op,OneShotAnalysisState & state)941 LogicalResult bufferization::analyzeOp(Operation *op,
942                                        OneShotAnalysisState &state) {
943   DominanceInfo domInfo(op);
944   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
945   const auto &options =
946       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
947 
948   // Catch incorrect API usage.
949   assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
950           !options.bufferizeFunctionBoundaries) &&
951          "must use ModuleBufferize to bufferize function boundaries");
952 
953   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
954     return failure();
955 
956   // If the analysis fails, just return.
957   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
958                              options.analysisFuzzerSeed)))
959     return failure();
960   equivalenceAnalysis(op, aliasInfo, state);
961 
962   bool failedAnalysis = false;
963   if (!options.allowReturnAllocs) {
964     SmallVector<Operation *> newOps;
965     failedAnalysis |=
966         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
967   }
968 
969   // Gather some extra analysis data.
970   state.gatherYieldedTensors(op);
971   state.gatherUndefinedTensorUses(op);
972 
973   // Analysis verification: After setting up alias/equivalence sets, each op
974   // can check for expected invariants/limitations and fail the analysis if
975   // necessary.
976   op->walk([&](Operation *op) {
977     if (BufferizableOpInterface bufferizableOp =
978             options.dynCastBufferizableOp(op))
979       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
980   });
981 
982   // Annotate operations if we only want to report the analysis.
983   if (options.testAnalysisOnly)
984     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
985 
986   return success(!failedAnalysis);
987 }
988 
989 LogicalResult
runOneShotBufferize(Operation * op,const OneShotBufferizationOptions & options)990 bufferization::runOneShotBufferize(Operation *op,
991                                    const OneShotBufferizationOptions &options) {
992   OneShotAnalysisState state(op, options);
993   if (failed(insertTensorCopies(op, options)))
994     return failure();
995   if (options.testAnalysisOnly)
996     return success();
997   return bufferizeOp(op, options, /*copyBeforeWrite=*/false);
998 }
999