1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
12 // call graphs.
13 //
14 // One-Shot Bufferize consists of two phases.
15 //
16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
17 //    inserting buffer copies. The analysis queries op bufferization semantics
18 //    via `BufferizableOpInterface`.
19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
20 //    function does not generate buffer copies for OpResults that were decided
21 //    to bufferize inplace during the analysis phase.
22 //
23 // This file contains only the analysis. The actual bufferization is implemented
24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
25 // helper function `runOneShotBufferize` that analyzes an op (and its nested
26 // ops) and then bufferizes it.
27 //
28 // Inplace bufferization decisions are passed from the analysis to the
29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
30 // They can be printed for debugging purposes with `testAnalysisOnly`.
31 //
32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
33 // treated conservatively. E.g., the analysis has to assume that their tensor
34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
36 // inserted at the bufferization boundary.
37 //
38 // This analysis caters to high-performance codegen where buffer reuse is deemed
39 // critical: the analysis should fail if the bufferized form of the function
40 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
41 
42 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
43 
44 #include <random>
45 
46 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
47 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49 #include "mlir/Dialect/MemRef/IR/MemRef.h"
50 #include "mlir/IR/AsmState.h"
51 #include "mlir/IR/Dominance.h"
52 #include "mlir/IR/Operation.h"
53 #include "mlir/IR/TypeUtilities.h"
54 #include "mlir/Interfaces/ControlFlowInterfaces.h"
55 #include "llvm/ADT/DenseSet.h"
56 #include "llvm/ADT/SetVector.h"
57 
58 using namespace mlir;
59 using namespace mlir::bufferization;
60 
61 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
62 
63 //===----------------------------------------------------------------------===//
64 // Bufferization-specific attribute manipulation.
65 // These are for testing and debugging only. Bufferization information is
66 // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
67 // is annotated with the results of the analysis (copied from
68 // BufferizationAliasInfo), so that they can be checked in tests.
69 //===----------------------------------------------------------------------===//
70 
/// Name of the testing/debugging attribute that records the in-place decision
/// for each OpOperand of an op. It holds one string per operand: "true" or
/// "false" for tensor operands and "none" for all other operands (see
/// `setInPlaceOpOperand`). The entries are per operand, not per result,
/// despite the "Results" in the variable name.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
73 
74 /// Mark whether OpOperand will be bufferized inplace.
75 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
76   Operation *op = opOperand.getOwner();
77   auto attr =
78       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
79   SmallVector<StringRef> inPlaceVector;
80   if (attr) {
81     inPlaceVector = SmallVector<StringRef>(
82         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
83   } else {
84     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
85     for (OpOperand &opOperand : op->getOpOperands())
86       if (opOperand.get().getType().isa<TensorType>())
87         inPlaceVector[opOperand.getOperandNumber()] = "false";
88   }
89 
90   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
91   op->setAttr(kInPlaceResultsAttrName,
92               OpBuilder(op).getStrArrayAttr(inPlaceVector));
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // BufferizationAliasInfo
97 //===----------------------------------------------------------------------===//
98 
99 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
100   rootOp->walk([&](Operation *op) {
101     for (Value v : op->getResults())
102       if (v.getType().isa<TensorType>())
103         createAliasInfoEntry(v);
104     for (Region &r : op->getRegions())
105       for (Block &b : r.getBlocks())
106         for (auto bbArg : b.getArguments())
107           if (bbArg.getType().isa<TensorType>())
108             createAliasInfoEntry(bbArg);
109   });
110 }
111 
112 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
113 /// beginning the alias and equivalence sets only contain `v` itself.
114 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
115   aliasInfo.insert(v);
116   equivalentInfo.insert(v);
117 }
118 
119 /// Insert an info entry for `newValue` and merge its alias set with that of
120 /// `alias`.
121 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
122   createAliasInfoEntry(newValue);
123   aliasInfo.unionSets(newValue, alias);
124 }
125 
126 /// Insert an info entry for `newValue` and merge its alias set with that of
127 /// `alias`. Additionally, merge their equivalence classes.
128 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
129                                                         Value alias) {
130   insertNewBufferAlias(newValue, alias);
131   equivalentInfo.unionSets(newValue, alias);
132 }
133 
134 /// Return `true` if a value was marked as in-place bufferized.
135 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
136   return inplaceBufferized.contains(&operand);
137 }
138 
139 /// Set the inPlace bufferization spec to true.
140 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
141                                               AnalysisState &state) {
142   markInPlace(operand);
143   for (OpResult result : state.getAliasingOpResult(operand))
144     aliasInfo.unionSets(result, operand.get());
145 }
146 
147 /// Set the inPlace bufferization spec to false.
148 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
149   assert(!inplaceBufferized.contains(&operand) &&
150          "OpOperand was already decided to bufferize inplace");
151 }
152 
153 /// Apply `fun` to all the members of the equivalence class of `v`.
154 void BufferizationAliasInfo::applyOnEquivalenceClass(
155     Value v, function_ref<void(Value)> fun) const {
156   auto leaderIt = equivalentInfo.findLeader(v);
157   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
158        ++mit) {
159     fun(*mit);
160   }
161 }
162 
163 /// Apply `fun` to all aliases of `v`.
164 void BufferizationAliasInfo::applyOnAliases(
165     Value v, function_ref<void(Value)> fun) const {
166   auto leaderIt = aliasInfo.findLeader(v);
167   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
168     fun(*mit);
169   }
170 }
171 
172 BufferizationAliasInfo::EquivalenceClassRangeType
173 BufferizationAliasInfo::getAliases(Value v) const {
174   DenseSet<Value> res;
175   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
176   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
177        mit != meit; ++mit) {
178     res.insert(static_cast<Value>(*mit));
179   }
180   return BufferizationAliasInfo::EquivalenceClassRangeType(
181       aliasInfo.member_begin(it), aliasInfo.member_end());
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // OneShotAnalysisState
186 //===----------------------------------------------------------------------===//
187 
188 OneShotAnalysisState::OneShotAnalysisState(
189     Operation *op, const OneShotBufferizationOptions &options)
190     : AnalysisState(options), aliasInfo(op) {
191   // Set up alias sets for OpResults that must bufferize in-place. This should
192   // be done before making any other bufferization decisions.
193   op->walk([&](BufferizableOpInterface bufferizableOp) {
194     if (!options.isOpAllowed(bufferizableOp))
195       return WalkResult::skip();
196     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
197       if (opOperand.get().getType().isa<TensorType>())
198         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
199           for (OpResult opResult :
200                bufferizableOp.getAliasingOpResult(opOperand, *this))
201             aliasInfo.unionAliasSets(opOperand.get(), opResult);
202           aliasInfo.markInPlace(opOperand);
203         }
204     }
205     return WalkResult::advance();
206   });
207 }
208 
209 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
210   return aliasInfo.isInPlace(opOperand);
211 }
212 
213 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
214                                                          Value v2) const {
215   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
216 }
217 
218 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
219 // to ensure that such information is available during bufferization time.
220 // Alias information can no longer be queried through BufferizationAliasInfo
221 // once we have started modifying the IR.
222 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
223   op->walk([&](Operation *returnOp) {
224     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
225       return WalkResult::advance();
226 
227     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
228       Value returnVal = returnValOperand.get();
229       // Skip non-tensor values.
230       if (!returnVal.getType().isa<TensorType>())
231         continue;
232 
233       // Add all aliases of the returned value. But only the ones that are in
234       // the same block.
235       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
236         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
237           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
238             yieldedTensors.insert(bbArg);
239           return;
240         }
241         Operation *definingOp = v.getDefiningOp();
242         if (definingOp->getParentOp() == returnOp->getParentOp())
243           yieldedTensors.insert(v);
244       });
245     }
246 
247     return WalkResult::advance();
248   });
249 }
250 
251 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
252   return yieldedTensors.contains(tensor);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Bufferization-specific alias analysis.
257 //===----------------------------------------------------------------------===//
258 
259 /// Return true if opOperand has been decided to bufferize in-place.
260 static bool isInplaceMemoryWrite(OpOperand &opOperand,
261                                  const BufferizationAliasInfo &aliasInfo,
262                                  AnalysisState &state) {
263   // OpOperands that do not bufferize to a memory write do not write in-place.
264   if (!state.bufferizesToMemoryWrite(opOperand))
265     return false;
266   // Check current bufferization decisions.
267   return aliasInfo.isInPlace(opOperand);
268 }
269 
270 /// Return true if, under current bufferization decisions, the buffer of `value`
271 /// is not writable.
272 static bool aliasesNonWritableBuffer(Value value,
273                                      const BufferizationAliasInfo &aliasInfo,
274                                      AnalysisState &state) {
275   bool foundNonWritableBuffer = false;
276   aliasInfo.applyOnAliases(value, [&](Value v) {
277     // Query BufferizableOpInterface to see if the value is writable.
278     // TODO: Out-of-place bufferized value could be considered writable.
279     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
280       if (bufferizableOp && bufferizableOp.isWritable(v, state))
281         return;
282 
283     // Query BufferizableOpInterface to see if the BlockArgument is writable.
284     if (auto bbArg = v.dyn_cast<BlockArgument>())
285       if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
286               bbArg.getOwner()->getParentOp()))
287         if (bufferizableOp.isWritable(bbArg, state))
288           return;
289 
290     foundNonWritableBuffer = true;
291   });
292 
293   return foundNonWritableBuffer;
294 }
295 
296 /// Return true if the buffer to which `operand` would bufferize is equivalent
297 /// to some buffer write.
298 static bool aliasesInPlaceWrite(Value value,
299                                 const BufferizationAliasInfo &aliasInfo,
300                                 AnalysisState &state) {
301   bool foundInplaceWrite = false;
302   aliasInfo.applyOnAliases(value, [&](Value v) {
303     for (auto &use : v.getUses()) {
304       if (isInplaceMemoryWrite(use, aliasInfo, state)) {
305         foundInplaceWrite = true;
306         return;
307       }
308     }
309   });
310   return foundInplaceWrite;
311 }
312 
313 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
314 /// properly dominates `b` and `b` is not inside `a`.
315 static bool happensBefore(Operation *a, Operation *b,
316                           const DominanceInfo &domInfo) {
317   do {
318     // TODO: Instead of isProperAncestor + properlyDominates, we should use
319     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
320     if (a->isProperAncestor(b))
321       return false;
322     if (domInfo.properlyDominates(a, b))
323       return true;
324   } while ((a = a->getParentOp()));
325   return false;
326 }
327 
328 /// Annotate IR with details about the detected RaW conflict.
329 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
330                              Value lastWrite) {
331   static uint64_t counter = 0;
332   Operation *readingOp = uRead->getOwner();
333   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
334 
335   OpBuilder b(conflictingWritingOp->getContext());
336   std::string id = "C_" + std::to_string(counter++);
337 
338   std::string conflictingWriteAttr =
339       id +
340       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
341       "]";
342   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
343 
344   std::string readAttr =
345       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
346   readingOp->setAttr(readAttr, b.getUnitAttr());
347 
348   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
349     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
350                                 std::to_string(opResult.getResultNumber()) +
351                                 "]";
352     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
353   } else {
354     auto bbArg = lastWrite.cast<BlockArgument>();
355     std::string lastWriteAttr =
356         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
357     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
358   }
359 }
360 
361 /// Given sets of uses and writes, return true if there is a RaW conflict under
362 /// the assumption that all given reads/writes alias the same buffer and that
363 /// all given writes bufferize inplace.
364 ///
365 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
366 /// the result of a write W1. But because of bufferization decisions, R actually
367 /// reads another write W2.
368 static bool hasReadAfterWriteInterference(
369     const DenseSet<OpOperand *> &usesRead,
370     const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
371     AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
372   const BufferizationOptions &options = state.getOptions();
373 
374   for (OpOperand *uRead : usesRead) {
375     Operation *readingOp = uRead->getOwner();
376 
377     // Find most recent writes of uRead by following the SSA use-def chain.
378     // E.g.:
379     //
380     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
381     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
382     // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
383     //
384     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
385     // is %0. Note that operations that create an alias but do not write (such
386     // as ExtractSliceOp) are skipped.
387     SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
388 
389     // Look for conflicting memory writes. Potential conflicts are writes to an
390     // alias that have been decided to bufferize inplace.
391     for (OpOperand *uConflictingWrite : usesWrite) {
392       // Throughout this loop, check for multiple requirements that have to be
393       // met for uConflictingWrite to be an actual conflict.
394       Operation *conflictingWritingOp = uConflictingWrite->getOwner();
395 
396       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
397       // write is not visible when reading.
398       if (happensBefore(readingOp, conflictingWritingOp, domInfo))
399         continue;
400 
401       // No conflict if the reading use equals the use of the conflicting write.
402       // A use cannot conflict with itself. Note: Just being the same op is not
403       // enough. It has to be the same use.
404       if (uConflictingWrite == uRead)
405         continue;
406 
407       // No conflict if the op interface says so.
408       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
409         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
410           continue;
411 
412       if (conflictingWritingOp != readingOp)
413         if (auto bufferizableOp =
414                 options.dynCastBufferizableOp(conflictingWritingOp))
415           if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
416             continue;
417 
418       // Ops are not conflicting if they are in mutually exclusive regions.
419       if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
420         continue;
421 
422       // Check all possible last writes.
423       for (Value lastWrite : lastWrites) {
424         // No conflict if the conflicting write happens before the last
425         // write.
426         if (Operation *writingOp = lastWrite.getDefiningOp()) {
427           if (happensBefore(conflictingWritingOp, writingOp, domInfo))
428             // conflictingWritingOp happens before writingOp. No conflict.
429             continue;
430           // No conflict if conflictingWritingOp is contained in writingOp.
431           if (writingOp->isProperAncestor(conflictingWritingOp))
432             continue;
433         } else {
434           auto bbArg = lastWrite.cast<BlockArgument>();
435           Block *block = bbArg.getOwner();
436           if (!block->findAncestorOpInBlock(*conflictingWritingOp))
437             // conflictingWritingOp happens outside of the block. No
438             // conflict.
439             continue;
440         }
441 
442         // No conflict if the conflicting write and the last write are the same
443         // use.
444         SmallVector<OpResult> aliasingOpResult =
445             state.getAliasingOpResult(*uConflictingWrite);
446         if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
447           continue;
448 
449         // All requirements are met. Conflict found!
450 
451         if (options.printConflicts)
452           annotateConflict(uRead, uConflictingWrite, lastWrite);
453 
454         return true;
455       }
456     }
457   }
458 
459   return false;
460 }
461 
462 /// Return true if bufferizing `operand` inplace would create a conflict. A read
463 /// R and a write W of the same alias set is a conflict if inplace bufferization
464 /// of W changes the value read by R to a value different from the one that
465 /// would be expected by tracing back R's origin through SSA use-def chains.
466 /// A conflict can only be introduced by a new alias and/or an inplace
467 /// bufferization decision.
468 ///
469 /// Example:
470 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
471 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
472 /// %e = tensor.extract_slice %1
473 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
474 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
475 ///
476 /// In the above example, the two TransferWriteOps have already been decided to
477 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
478 /// conflict because:
479 /// * According to SSA use-def chains, we expect to read the result of %1.
480 /// * However, adding an alias {%0, %t} would mean that the second
481 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
482 ///   would no longer be reading the result of %1.
483 ///
484 /// If `checkConsistencyOnly` is true, this function checks if there is a
485 /// read-after-write conflict without bufferizing `operand` inplace. This would
486 /// indicate a problem with the current inplace bufferization decisions.
487 ///
488 /// Note: If `checkConsistencyOnly`, this function may be called with a null
489 /// OpResult. In that case, only the consistency of bufferization decisions
490 /// involving aliases of the given OpOperand are checked.
491 static bool wouldCreateReadAfterWriteInterference(
492     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
493     const BufferizationAliasInfo &aliasInfo,
494     bool checkConsistencyOnly = false) {
495   // Helper function to iterate on aliases of `root` and capture the reads.
496   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
497     aliasInfo.applyOnAliases(root, [&](Value alias) {
498       for (auto &use : alias.getUses())
499         // Read to a value that aliases root.
500         if (state.bufferizesToMemoryRead(use))
501           res.insert(&use);
502     });
503   };
504 
505   // Helper function to iterate on aliases of `root` and capture the writes.
506   auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
507     aliasInfo.applyOnAliases(root, [&](Value alias) {
508       for (auto &use : alias.getUses())
509         // Inplace write to a value that aliases root.
510         if (isInplaceMemoryWrite(use, aliasInfo, state))
511           res.insert(&use);
512     });
513   };
514 
515   // Collect reads and writes of all aliases of OpOperand and OpResult.
516   DenseSet<OpOperand *> usesRead, usesWrite;
517   getAliasingReads(usesRead, operand.get());
518   getAliasingInplaceWrites(usesWrite, operand.get());
519   for (OpResult result : state.getAliasingOpResult(operand)) {
520     getAliasingReads(usesRead, result);
521     getAliasingInplaceWrites(usesWrite, result);
522   }
523   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
524     usesWrite.insert(&operand);
525 
526   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
527                                        aliasInfo);
528 }
529 
530 /// Return true if bufferizing `opOperand` inplace would create a write to a
531 /// non-writable buffer.
532 static bool
533 wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
534                                     const BufferizationAliasInfo &aliasInfo,
535                                     AnalysisState &state) {
536   // Certain buffers are not writeable:
537   //   1. A function bbArg that is not inplaceable or
538   //   2. A constant op.
539   bool nonWritable =
540       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
541   if (!nonWritable)
542     return false;
543 
544   // This is a problem only if the buffer is written to via some alias.
545   bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
546                   state.bufferizesToMemoryWrite(opOperand);
547 
548   for (OpResult opResult : state.getAliasingOpResult(opOperand))
549     hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
550 
551   return hasWrite;
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // Bufferization analyses.
556 //===----------------------------------------------------------------------===//
557 
558 /// Determine if `operand` can be bufferized in-place.
559 static LogicalResult bufferizableInPlaceAnalysisImpl(
560     OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
561     const DominanceInfo &domInfo) {
562   bool foundInterference =
563       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
564       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
565 
566   if (foundInterference)
567     aliasInfo.bufferizeOutOfPlace(operand);
568   else
569     aliasInfo.bufferizeInPlace(operand, state);
570 
571   return success();
572 }
573 
574 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
575 /// reverse and bufferize ops greedily. This is a good starter heuristic.
576 ///
577 /// Even if an op does not read or write, it may still create an alias when
578 /// bufferized in-place. An example of such ops is tensor.extract_slice.
579 ///
580 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
581 ///
582 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
583 /// cannot change the flow of information for either the source or the
584 /// result buffers.
585 ///
586 /// When bufferized inplace, an ExtractSliceOp does not by itself create any
587 /// read or write from memory. Instead, it has the effect of merging the alias
588 /// sets of the source and the result buffers.
589 ///
590 /// An analysis is required to ensure inplace bufferization would not result in
591 /// RaW dependence violations.
592 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
593                                      BufferizationAliasInfo &aliasInfo,
594                                      AnalysisState &state,
595                                      const DominanceInfo &domInfo,
596                                      unsigned analysisFuzzerSeed = 0) {
597   if (analysisFuzzerSeed) {
598     // This is a fuzzer. For testing purposes only. Randomize the order in which
599     // operations are analyzed. The bufferization quality is likely worse, but
600     // we want to make sure that no assertions are triggered anywhere.
601     std::mt19937 g(analysisFuzzerSeed);
602     llvm::shuffle(ops.begin(), ops.end(), g);
603   }
604 
605   // Walk ops in reverse for better interference analysis.
606   for (Operation *op : reverse(ops))
607     for (OpOperand &opOperand : op->getOpOperands())
608       if (opOperand.get().getType().isa<TensorType>())
609         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
610           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
611                                                      state, domInfo)))
612             return failure();
613 
614   return success();
615 }
616 
617 /// Return true if the given op has a tensor result or a tensor operand.
618 static bool hasTensorSemantics(Operation *op) {
619   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
620   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
621   return hasTensorResult || hasTensorOperand;
622 }
623 
624 /// Analyze all ops that are contained in `op`.
625 static LogicalResult inPlaceAnalysis(Operation *op,
626                                      BufferizationAliasInfo &aliasInfo,
627                                      AnalysisState &state,
628                                      const DominanceInfo &domInfo,
629                                      unsigned analysisFuzzerSeed = 0) {
630   // Collect ops so we can build our own reverse traversal.
631   SmallVector<Operation *> ops;
632   op->walk([&](Operation *op) {
633     // No tensors => no buffers.
634     if (!hasTensorSemantics(op))
635       return;
636     ops.push_back(op);
637   });
638 
639   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
640 }
641 
642 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
643 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
644                                 BufferizationAliasInfo &aliasInfo,
645                                 AnalysisState &state) {
646   for (Operation *op : ops)
647     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
648       for (OpResult opResult : op->getOpResults())
649         if (opResult.getType().isa<TensorType>())
650           for (OpOperand *opOperand :
651                bufferizableOp.getAliasingOpOperand(opResult, state))
652             if (state.isInPlace(*opOperand))
653               if (bufferizableOp.bufferRelation(opResult, state) ==
654                   BufferRelation::Equivalent)
655                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
656 }
657 
658 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
659 /// in `op`.
660 static void equivalenceAnalysis(Operation *op,
661                                 BufferizationAliasInfo &aliasInfo,
662                                 AnalysisState &state) {
663   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
664   SmallVector<Operation *> ops;
665   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
666     // No tensors => no buffers.
667     if (none_of(op->getResultTypes(), isaTensor))
668       return;
669     ops.push_back(op);
670   });
671 
672   equivalenceAnalysis(ops, aliasInfo, state);
673 }
674 
675 /// Assert that the current bufferization decisions are consistent.
676 static LogicalResult
677 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
678                           AnalysisState &state,
679                           const BufferizationAliasInfo &aliasInfo) {
680   const BufferizationOptions &options = state.getOptions();
681   Operation *inconsistentOp = nullptr;
682   WalkResult walkResult = op->walk([&](Operation *op) {
683     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
684       for (OpOperand &opOperand : op->getOpOperands())
685         if (opOperand.get().getType().isa<TensorType>()) {
686           if (wouldCreateReadAfterWriteInterference(
687                   opOperand, domInfo, state, aliasInfo,
688                   /*checkConsistencyOnly=*/true)) {
689             // This error can happen if certain "mustBufferizeInPlace" interface
690             // methods are implemented incorrectly, such that the IR already has
691             // a RaW conflict before making any bufferization decisions.
692             inconsistentOp = op;
693             return WalkResult::interrupt();
694           }
695         }
696     return WalkResult::advance();
697   });
698 
699   if (walkResult.wasInterrupted())
700     return inconsistentOp->emitError("input IR has RaW conflict");
701   return success();
702 }
703 
704 /// Annotate the IR with the result of the analysis. For testing/debugging only.
705 static void
706 annotateOpsWithBufferizationMarkers(Operation *op,
707                                     const BufferizationAliasInfo &aliasInfo,
708                                     AnalysisState &state) {
709   op->walk([&](Operation *op) {
710     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
711       for (OpOperand &opOperand : op->getOpOperands())
712         if (opOperand.get().getType().isa<TensorType>())
713           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
714   });
715 }
716 
717 /// Assert that IR is in destination-passing style. I.e., every value that is
718 /// returned or yielded from a block is:
719 /// * aliasing a bbArg of that block or a parent block, or
720 /// * aliasing an OpResult of a op in a parent block.
721 ///
722 /// Example:
723 /// ```
724 /// %0 = "some_op" : tensor<?xf32>
725 /// %1 = scf.if %c -> (tensor<?xf32>) {
726 ///   scf.yield %0 : tensor<?xf32>
727 /// } else {
728 ///   %t = linalg.init_tensor : tensor<?xf32>
729 ///   scf.yield %t : tensor<?xf32>
730 /// }
731 /// ```
732 /// In the above example, the first scf.yield op satifies destination-passing
733 /// style because the yielded value %0 is defined in the parent block. The
734 /// second scf.yield op does not satisfy destination-passing style because the
735 /// yielded value %t is defined in the same block as the scf.yield op.
736 // TODO: The current implementation checks for equivalent values instead of
737 // aliasing values, which is stricter than needed. We can currently not check
738 // for aliasing values because the analysis is a maybe-alias analysis and we
739 // need a must-alias analysis here.
740 static LogicalResult
741 assertDestinationPassingStyle(Operation *op, AnalysisState &state,
742                               BufferizationAliasInfo &aliasInfo,
743                               SmallVector<Operation *> &newOps) {
744   LogicalResult status = success();
745   DominanceInfo domInfo(op);
746   op->walk([&](Operation *returnOp) {
747     if (!isRegionReturnLike(returnOp) ||
748         !state.getOptions().isOpAllowed(returnOp))
749       return WalkResult::advance();
750 
751     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
752       Value returnVal = returnValOperand.get();
753       // Skip non-tensor values.
754       if (!returnVal.getType().isa<TensorType>())
755         continue;
756 
757       bool foundEquivValue = false;
758       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
759         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
760           Operation *definingOp = bbArg.getOwner()->getParentOp();
761           if (definingOp->isProperAncestor(returnOp))
762             foundEquivValue = true;
763           return;
764         }
765 
766         Operation *definingOp = equivVal.getDefiningOp();
767         if (definingOp->getBlock()->findAncestorOpInBlock(
768                 *returnOp->getParentOp()))
769           // Skip ops that happen after `returnOp` and parent ops.
770           if (happensBefore(definingOp, returnOp, domInfo))
771             foundEquivValue = true;
772       });
773 
774       if (!foundEquivValue)
775         status =
776             returnOp->emitError()
777             << "operand #" << returnValOperand.getOperandNumber()
778             << " of ReturnLike op does not satisfy destination passing style";
779     }
780 
781     return WalkResult::advance();
782   });
783 
784   return status;
785 }
786 
787 LogicalResult bufferization::analyzeOp(Operation *op,
788                                        OneShotAnalysisState &state) {
789   DominanceInfo domInfo(op);
790   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
791   const auto &options =
792       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
793 
794   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
795     return failure();
796 
797   // If the analysis fails, just return.
798   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
799                              options.analysisFuzzerSeed)))
800     return failure();
801   equivalenceAnalysis(op, aliasInfo, state);
802 
803   for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
804     SmallVector<Operation *> newOps;
805     if (failed(fn(op, state, aliasInfo, newOps)))
806       return failure();
807     // Analyze ops that were created by the PostAnalysisStepFn.
808     if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
809       return failure();
810     equivalenceAnalysis(newOps, aliasInfo, state);
811   }
812 
813   bool failedAnalysis = false;
814   if (!options.allowReturnAllocs) {
815     SmallVector<Operation *> newOps;
816     failedAnalysis |=
817         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
818   }
819 
820   // Gather all yielded tensors.
821   state.gatherYieldedTensors(op);
822 
823   // Analysis verification: After setting up alias/equivalence sets, each op
824   // can check for expected invariants/limitations and fail the analysis if
825   // necessary.
826   op->walk([&](Operation *op) {
827     if (BufferizableOpInterface bufferizableOp =
828             options.dynCastBufferizableOp(op))
829       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
830   });
831 
832   // Annotate operations if we only want to report the analysis.
833   if (options.testAnalysisOnly)
834     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
835 
836   return success(!failedAnalysis);
837 }
838 
839 LogicalResult
840 bufferization::runOneShotBufferize(Operation *op,
841                                    const OneShotBufferizationOptions &options) {
842   OneShotAnalysisState state(op, options);
843   if (failed(analyzeOp(op, state)))
844     return failure();
845   if (options.testAnalysisOnly)
846     return success();
847   return bufferizeOp(op, state);
848 }
849