17a1579acSMatthias Springer //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer //
97a1579acSMatthias Springer // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
107a1579acSMatthias Springer // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
117a1579acSMatthias Springer // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
127a1579acSMatthias Springer // call graphs.
137a1579acSMatthias Springer //
147a1579acSMatthias Springer // One-Shot Bufferize consists of two phases.
157a1579acSMatthias Springer //
167a1579acSMatthias Springer // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
177a1579acSMatthias Springer //    inserting buffer copies. The analysis queries op bufferization semantics
187a1579acSMatthias Springer //    via `BufferizableOpInterface`.
197a1579acSMatthias Springer // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
207a1579acSMatthias Springer //    function does not generate buffer copies for OpResults that were decided
217a1579acSMatthias Springer //    to bufferize inplace during the analysis phase.
227a1579acSMatthias Springer //
237a1579acSMatthias Springer // This file contains only the analysis. The actual bufferization is implemented
247a1579acSMatthias Springer // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
257a1579acSMatthias Springer // helper function `runOneShotBufferize` that analyzes an op (and its nested
267a1579acSMatthias Springer // ops) and then bufferizes it.
277a1579acSMatthias Springer //
287a1579acSMatthias Springer // Inplace bufferization decisions are passed from the analysis to the
299597b16aSMatthias Springer // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
307a1579acSMatthias Springer // They can be printed for debugging purposes with `testAnalysisOnly`.
317a1579acSMatthias Springer //
327a1579acSMatthias Springer // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
337a1579acSMatthias Springer // treated conservatively. E.g., the analysis has to assume that their tensor
347a1579acSMatthias Springer // OpOperands bufferize to memory writes. While such ops can be analyzed, they
357a1579acSMatthias Springer // are not bufferized and remain in the IR. to_tensor and to_memref ops are
367a1579acSMatthias Springer // inserted at the bufferization boundary.
377a1579acSMatthias Springer //
387a1579acSMatthias Springer // This analysis caters to high-performance codegen where buffer reuse is deemed
397a1579acSMatthias Springer // critical: the analysis should fail if the bufferized form of the function
40855a11eeSMatthias Springer // needs to return a buffer, unless `allowReturnAllocs` is enabled.
417a1579acSMatthias Springer 
427a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
437a1579acSMatthias Springer 
447a1579acSMatthias Springer #include <random>
457a1579acSMatthias Springer 
467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49*b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
50d6dab38aSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
517a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
527a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
537a1579acSMatthias Springer #include "mlir/IR/Dominance.h"
547a1579acSMatthias Springer #include "mlir/IR/Operation.h"
557a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
567a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
577a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h"
587a1579acSMatthias Springer #include "llvm/ADT/SetVector.h"
597a1579acSMatthias Springer 
607a1579acSMatthias Springer using namespace mlir;
617a1579acSMatthias Springer using namespace mlir::bufferization;
627a1579acSMatthias Springer 
isaTensor(Type t)637a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); }
647a1579acSMatthias Springer 
657a1579acSMatthias Springer //===----------------------------------------------------------------------===//
667a1579acSMatthias Springer // Bufferization-specific attribute manipulation.
677a1579acSMatthias Springer // These are for testing and debugging only. Bufferization information is
687a1579acSMatthias Springer // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
697a1579acSMatthias Springer // is annotated with the results of the analysis (copied from
707a1579acSMatthias Springer // BufferizationAliasInfo), so that they can be checked in tests.
717a1579acSMatthias Springer //===----------------------------------------------------------------------===//
727a1579acSMatthias Springer 
737a1579acSMatthias Springer /// Attribute marker to specify op results that can be bufferized inPlace.
747a1579acSMatthias Springer constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
757a1579acSMatthias Springer 
767a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace.
setInPlaceOpOperand(OpOperand & opOperand,bool inPlace)777a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
787a1579acSMatthias Springer   Operation *op = opOperand.getOwner();
797a1579acSMatthias Springer   auto attr =
807a1579acSMatthias Springer       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
817a1579acSMatthias Springer   SmallVector<StringRef> inPlaceVector;
827a1579acSMatthias Springer   if (attr) {
837a1579acSMatthias Springer     inPlaceVector = SmallVector<StringRef>(
847a1579acSMatthias Springer         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
857a1579acSMatthias Springer   } else {
867a1579acSMatthias Springer     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
877a1579acSMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
887a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
897a1579acSMatthias Springer         inPlaceVector[opOperand.getOperandNumber()] = "false";
907a1579acSMatthias Springer   }
917a1579acSMatthias Springer 
927a1579acSMatthias Springer   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
937a1579acSMatthias Springer   op->setAttr(kInPlaceResultsAttrName,
947a1579acSMatthias Springer               OpBuilder(op).getStrArrayAttr(inPlaceVector));
957a1579acSMatthias Springer }
967a1579acSMatthias Springer 
977a1579acSMatthias Springer //===----------------------------------------------------------------------===//
987a1579acSMatthias Springer // BufferizationAliasInfo
997a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1007a1579acSMatthias Springer 
BufferizationAliasInfo(Operation * rootOp)1017a1579acSMatthias Springer BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
1027a1579acSMatthias Springer   rootOp->walk([&](Operation *op) {
1037a1579acSMatthias Springer     for (Value v : op->getResults())
1047a1579acSMatthias Springer       if (v.getType().isa<TensorType>())
1057a1579acSMatthias Springer         createAliasInfoEntry(v);
1067a1579acSMatthias Springer     for (Region &r : op->getRegions())
1077a1579acSMatthias Springer       for (Block &b : r.getBlocks())
1087a1579acSMatthias Springer         for (auto bbArg : b.getArguments())
1097a1579acSMatthias Springer           if (bbArg.getType().isa<TensorType>())
1107a1579acSMatthias Springer             createAliasInfoEntry(bbArg);
1117a1579acSMatthias Springer   });
1127a1579acSMatthias Springer }
1137a1579acSMatthias Springer 
1147a1579acSMatthias Springer /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
1157a1579acSMatthias Springer /// beginning the alias and equivalence sets only contain `v` itself.
createAliasInfoEntry(Value v)1167a1579acSMatthias Springer void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
1177a1579acSMatthias Springer   aliasInfo.insert(v);
1187a1579acSMatthias Springer   equivalentInfo.insert(v);
1197a1579acSMatthias Springer }
1207a1579acSMatthias Springer 
1217a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1227a1579acSMatthias Springer /// `alias`.
insertNewBufferAlias(Value newValue,Value alias)1237a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
1247a1579acSMatthias Springer   createAliasInfoEntry(newValue);
1257a1579acSMatthias Springer   aliasInfo.unionSets(newValue, alias);
1267a1579acSMatthias Springer }
1277a1579acSMatthias Springer 
1287a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1297a1579acSMatthias Springer /// `alias`. Additionally, merge their equivalence classes.
insertNewBufferEquivalence(Value newValue,Value alias)1307a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
1317a1579acSMatthias Springer                                                         Value alias) {
1327a1579acSMatthias Springer   insertNewBufferAlias(newValue, alias);
1337a1579acSMatthias Springer   equivalentInfo.unionSets(newValue, alias);
1347a1579acSMatthias Springer }
1357a1579acSMatthias Springer 
1367a1579acSMatthias Springer /// Return `true` if a value was marked as in-place bufferized.
isInPlace(OpOperand & operand) const1377a1579acSMatthias Springer bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
1387a1579acSMatthias Springer   return inplaceBufferized.contains(&operand);
1397a1579acSMatthias Springer }
1407a1579acSMatthias Springer 
1417a1579acSMatthias Springer /// Set the inPlace bufferization spec to true.
bufferizeInPlace(OpOperand & operand,AnalysisState & state)1427a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
1439597b16aSMatthias Springer                                               AnalysisState &state) {
1447a1579acSMatthias Springer   markInPlace(operand);
145585a8a32SMatthias Springer   for (OpResult result : state.getAliasingOpResult(operand))
1467a1579acSMatthias Springer     aliasInfo.unionSets(result, operand.get());
1477a1579acSMatthias Springer }
1487a1579acSMatthias Springer 
1497a1579acSMatthias Springer /// Set the inPlace bufferization spec to false.
bufferizeOutOfPlace(OpOperand & operand)1507a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
1517a1579acSMatthias Springer   assert(!inplaceBufferized.contains(&operand) &&
1527a1579acSMatthias Springer          "OpOperand was already decided to bufferize inplace");
1537a1579acSMatthias Springer }
1547a1579acSMatthias Springer 
1557a1579acSMatthias Springer /// Apply `fun` to all the members of the equivalence class of `v`.
applyOnEquivalenceClass(Value v,function_ref<void (Value)> fun) const1567a1579acSMatthias Springer void BufferizationAliasInfo::applyOnEquivalenceClass(
1577a1579acSMatthias Springer     Value v, function_ref<void(Value)> fun) const {
1587a1579acSMatthias Springer   auto leaderIt = equivalentInfo.findLeader(v);
1597a1579acSMatthias Springer   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
1607a1579acSMatthias Springer        ++mit) {
1617a1579acSMatthias Springer     fun(*mit);
1627a1579acSMatthias Springer   }
1637a1579acSMatthias Springer }
1647a1579acSMatthias Springer 
1657a1579acSMatthias Springer /// Apply `fun` to all aliases of `v`.
applyOnAliases(Value v,function_ref<void (Value)> fun) const1667a1579acSMatthias Springer void BufferizationAliasInfo::applyOnAliases(
1677a1579acSMatthias Springer     Value v, function_ref<void(Value)> fun) const {
1687a1579acSMatthias Springer   auto leaderIt = aliasInfo.findLeader(v);
1697a1579acSMatthias Springer   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
1707a1579acSMatthias Springer     fun(*mit);
1717a1579acSMatthias Springer   }
1727a1579acSMatthias Springer }
1737a1579acSMatthias Springer 
1747a1579acSMatthias Springer BufferizationAliasInfo::EquivalenceClassRangeType
getAliases(Value v) const1757a1579acSMatthias Springer BufferizationAliasInfo::getAliases(Value v) const {
1767a1579acSMatthias Springer   DenseSet<Value> res;
1777a1579acSMatthias Springer   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
1787a1579acSMatthias Springer   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
1797a1579acSMatthias Springer        mit != meit; ++mit) {
1807a1579acSMatthias Springer     res.insert(static_cast<Value>(*mit));
1817a1579acSMatthias Springer   }
1827a1579acSMatthias Springer   return BufferizationAliasInfo::EquivalenceClassRangeType(
1837a1579acSMatthias Springer       aliasInfo.member_begin(it), aliasInfo.member_end());
1847a1579acSMatthias Springer }
1857a1579acSMatthias Springer 
1867a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1879597b16aSMatthias Springer // OneShotAnalysisState
1887a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1897a1579acSMatthias Springer 
OneShotAnalysisState(Operation * op,const OneShotBufferizationOptions & options)1909597b16aSMatthias Springer OneShotAnalysisState::OneShotAnalysisState(
1919597b16aSMatthias Springer     Operation *op, const OneShotBufferizationOptions &options)
1929597b16aSMatthias Springer     : AnalysisState(options), aliasInfo(op) {
1937a1579acSMatthias Springer   // Set up alias sets for OpResults that must bufferize in-place. This should
1947a1579acSMatthias Springer   // be done before making any other bufferization decisions.
1957a1579acSMatthias Springer   op->walk([&](BufferizableOpInterface bufferizableOp) {
1967a1579acSMatthias Springer     if (!options.isOpAllowed(bufferizableOp))
1977a1579acSMatthias Springer       return WalkResult::skip();
1987a1579acSMatthias Springer     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
1997a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
2007a1579acSMatthias Springer         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
201585a8a32SMatthias Springer           for (OpResult opResult :
2027a1579acSMatthias Springer                bufferizableOp.getAliasingOpResult(opOperand, *this))
2037a1579acSMatthias Springer             aliasInfo.unionAliasSets(opOperand.get(), opResult);
2047a1579acSMatthias Springer           aliasInfo.markInPlace(opOperand);
2057a1579acSMatthias Springer         }
2067a1579acSMatthias Springer     }
2077a1579acSMatthias Springer     return WalkResult::advance();
2087a1579acSMatthias Springer   });
2097a1579acSMatthias Springer }
2107a1579acSMatthias Springer 
isInPlace(OpOperand & opOperand) const2119597b16aSMatthias Springer bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
2127a1579acSMatthias Springer   return aliasInfo.isInPlace(opOperand);
2137a1579acSMatthias Springer }
2147a1579acSMatthias Springer 
areEquivalentBufferizedValues(Value v1,Value v2) const2159597b16aSMatthias Springer bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
2167a1579acSMatthias Springer                                                          Value v2) const {
2177a1579acSMatthias Springer   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
2187a1579acSMatthias Springer }
2197a1579acSMatthias Springer 
areAliasingBufferizedValues(Value v1,Value v2) const2203490aadfSMatthias Springer bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
2213490aadfSMatthias Springer                                                        Value v2) const {
2223490aadfSMatthias Springer   return aliasInfo.areAliasingBufferizedValues(v1, v2);
2233490aadfSMatthias Springer }
2243490aadfSMatthias Springer 
2259e24f0f4SMatthias Springer // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
2269e24f0f4SMatthias Springer // to ensure that such information is available during bufferization time.
2279e24f0f4SMatthias Springer // Alias information can no longer be queried through BufferizationAliasInfo
2289e24f0f4SMatthias Springer // once we have started modifying the IR.
gatherYieldedTensors(Operation * op)2299e24f0f4SMatthias Springer void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
2309e24f0f4SMatthias Springer   op->walk([&](Operation *returnOp) {
2319e24f0f4SMatthias Springer     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
2329e24f0f4SMatthias Springer       return WalkResult::advance();
2339e24f0f4SMatthias Springer 
2349e24f0f4SMatthias Springer     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
2359e24f0f4SMatthias Springer       Value returnVal = returnValOperand.get();
2369e24f0f4SMatthias Springer       // Skip non-tensor values.
2379e24f0f4SMatthias Springer       if (!returnVal.getType().isa<TensorType>())
2389e24f0f4SMatthias Springer         continue;
2399e24f0f4SMatthias Springer 
2409e24f0f4SMatthias Springer       // Add all aliases of the returned value. But only the ones that are in
2419e24f0f4SMatthias Springer       // the same block.
2429e24f0f4SMatthias Springer       aliasInfo.applyOnAliases(returnVal, [&](Value v) {
2439e24f0f4SMatthias Springer         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
2449e24f0f4SMatthias Springer           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
2459e24f0f4SMatthias Springer             yieldedTensors.insert(bbArg);
2469e24f0f4SMatthias Springer           return;
2479e24f0f4SMatthias Springer         }
2489e24f0f4SMatthias Springer         Operation *definingOp = v.getDefiningOp();
2499e24f0f4SMatthias Springer         if (definingOp->getParentOp() == returnOp->getParentOp())
2509e24f0f4SMatthias Springer           yieldedTensors.insert(v);
2519e24f0f4SMatthias Springer       });
2529e24f0f4SMatthias Springer     }
2539e24f0f4SMatthias Springer 
2549e24f0f4SMatthias Springer     return WalkResult::advance();
2559e24f0f4SMatthias Springer   });
2569e24f0f4SMatthias Springer }
2579e24f0f4SMatthias Springer 
gatherUndefinedTensorUses(Operation * op)258988748c0SMatthias Springer void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
259988748c0SMatthias Springer   op->walk([&](Operation *op) {
260988748c0SMatthias Springer     // Skip unknown ops.
261988748c0SMatthias Springer     auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
262988748c0SMatthias Springer     if (!bufferizableOp)
263988748c0SMatthias Springer       return WalkResult::skip();
264988748c0SMatthias Springer 
265988748c0SMatthias Springer     // Check all tensor OpResults.
266988748c0SMatthias Springer     for (OpResult opResult : op->getOpResults()) {
267988748c0SMatthias Springer       if (!opResult.getType().isa<TensorType>())
268988748c0SMatthias Springer         continue;
269988748c0SMatthias Springer 
270988748c0SMatthias Springer       // If there is no preceding memory write, the tensor contents are
271988748c0SMatthias Springer       // undefined.
272988748c0SMatthias Springer       // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
273988748c0SMatthias Springer       // use-def chain, it returns that value, regardless of whether it is a
274988748c0SMatthias Springer       // memory write or not.
275988748c0SMatthias Springer       SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
276988748c0SMatthias Springer       bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
277988748c0SMatthias Springer         if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
278988748c0SMatthias Springer           return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
279988748c0SMatthias Springer                                               *this);
280988748c0SMatthias Springer         return true;
281988748c0SMatthias Springer       });
282988748c0SMatthias Springer       if (isUndefined)
283988748c0SMatthias Springer         for (OpOperand &use : opResult.getUses())
284988748c0SMatthias Springer           undefinedTensorUses.insert(&use);
285988748c0SMatthias Springer     }
286988748c0SMatthias Springer 
287988748c0SMatthias Springer     return WalkResult::advance();
288988748c0SMatthias Springer   });
289988748c0SMatthias Springer }
290988748c0SMatthias Springer 
hasUndefinedContents(OpOperand * opOperand) const291988748c0SMatthias Springer bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
292988748c0SMatthias Springer   return undefinedTensorUses.contains(opOperand);
293988748c0SMatthias Springer }
294988748c0SMatthias Springer 
isTensorYielded(Value tensor) const2959e24f0f4SMatthias Springer bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
2969e24f0f4SMatthias Springer   return yieldedTensors.contains(tensor);
2979e24f0f4SMatthias Springer }
2989e24f0f4SMatthias Springer 
isValueWritten(Value value) const2993490aadfSMatthias Springer bool OneShotAnalysisState::isValueWritten(Value value) const {
3003490aadfSMatthias Springer   bool isWritten = false;
3013490aadfSMatthias Springer   aliasInfo.applyOnAliases(value, [&](Value val) {
3023490aadfSMatthias Springer     for (OpOperand &use : val.getUses())
3033490aadfSMatthias Springer       if (isInPlace(use) && bufferizesToMemoryWrite(use))
3043490aadfSMatthias Springer         isWritten = true;
3053490aadfSMatthias Springer   });
3063490aadfSMatthias Springer   return isWritten;
3073490aadfSMatthias Springer }
3083490aadfSMatthias Springer 
isWritable(Value value) const309032be233SMatthias Springer bool OneShotAnalysisState::isWritable(Value value) const {
310032be233SMatthias Springer   // TODO: Out-of-place bufferized value could be considered writable.
311032be233SMatthias Springer   if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
312032be233SMatthias Springer     return bufferizableOp.isWritable(value, *this);
313032be233SMatthias Springer 
314032be233SMatthias Springer   // Query BufferizableOpInterface to see if the BlockArgument is writable.
315032be233SMatthias Springer   if (auto bbArg = value.dyn_cast<BlockArgument>())
316032be233SMatthias Springer     if (auto bufferizableOp =
317032be233SMatthias Springer             getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
318032be233SMatthias Springer       return bufferizableOp.isWritable(bbArg, *this);
319032be233SMatthias Springer 
320032be233SMatthias Springer   // Not a bufferizable op: The conservative answer is "not writable".
321032be233SMatthias Springer   return false;
322032be233SMatthias Springer }
323032be233SMatthias Springer 
3247a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3257a1579acSMatthias Springer // Bufferization-specific alias analysis.
3267a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3277a1579acSMatthias Springer 
3287a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place.
isInplaceMemoryWrite(OpOperand & opOperand,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)3297a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand,
3307a1579acSMatthias Springer                                  const BufferizationAliasInfo &aliasInfo,
331032be233SMatthias Springer                                  const AnalysisState &state) {
3327a1579acSMatthias Springer   // OpOperands that do not bufferize to a memory write do not write in-place.
3337a1579acSMatthias Springer   if (!state.bufferizesToMemoryWrite(opOperand))
3347a1579acSMatthias Springer     return false;
3357a1579acSMatthias Springer   // Check current bufferization decisions.
3367a1579acSMatthias Springer   return aliasInfo.isInPlace(opOperand);
3377a1579acSMatthias Springer }
3387a1579acSMatthias Springer 
3397a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
3407a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`.
happensBefore(Operation * a,Operation * b,const DominanceInfo & domInfo)3417a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b,
3427a1579acSMatthias Springer                           const DominanceInfo &domInfo) {
3437a1579acSMatthias Springer   do {
3447a1579acSMatthias Springer     // TODO: Instead of isProperAncestor + properlyDominates, we should use
3457a1579acSMatthias Springer     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
3467a1579acSMatthias Springer     if (a->isProperAncestor(b))
3477a1579acSMatthias Springer       return false;
3487a1579acSMatthias Springer     if (domInfo.properlyDominates(a, b))
3497a1579acSMatthias Springer       return true;
3507a1579acSMatthias Springer   } while ((a = a->getParentOp()));
3517a1579acSMatthias Springer   return false;
3527a1579acSMatthias Springer }
3537a1579acSMatthias Springer 
3549235e597SMatthias Springer /// For each given value, find the closest enclosing repetitive region. If this
3559235e597SMatthias Springer /// is the same region for each value, return it. Otherwise return None.
3569235e597SMatthias Springer /// Note: If there is no enclosing repetitive region, return nullptr.
3579235e597SMatthias Springer static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values)3589235e597SMatthias Springer getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
3599235e597SMatthias Springer   if (values.empty())
3609235e597SMatthias Springer     return None;
3619235e597SMatthias Springer   Region *r = getEnclosingRepetitiveRegion(values.front());
3629235e597SMatthias Springer   for (Value value : values.drop_front())
3639235e597SMatthias Springer     if (getEnclosingRepetitiveRegion(value) != r)
3649235e597SMatthias Springer       return None;
3659235e597SMatthias Springer   return r;
3669235e597SMatthias Springer }
3679235e597SMatthias Springer 
36837a14735SMatthias Springer /// Return `true` if the given tensor value is a memory write. Most values are
36937a14735SMatthias Springer /// tensor writes, but ops that define a tensor SSA value without specifying its
370ffdbecccSMatthias Springer /// contents (e.g., alloc_tensor) are not.
isMemoryWrite(Value value,const AnalysisState & state)37137a14735SMatthias Springer static bool isMemoryWrite(Value value, const AnalysisState &state) {
37237a14735SMatthias Springer   auto opResult = value.dyn_cast<OpResult>();
37337a14735SMatthias Springer   if (!opResult)
37437a14735SMatthias Springer     return true;
37537a14735SMatthias Springer   auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
37637a14735SMatthias Springer   if (!bufferizableOp)
37737a14735SMatthias Springer     return true;
37837a14735SMatthias Springer   return bufferizableOp.isMemoryWrite(opResult, state);
37937a14735SMatthias Springer }
38037a14735SMatthias Springer 
3817a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict.
annotateConflict(OpOperand * uRead,OpOperand * uConflictingWrite,Value lastWrite)3827a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
3837a1579acSMatthias Springer                              Value lastWrite) {
3847a1579acSMatthias Springer   static uint64_t counter = 0;
3857a1579acSMatthias Springer   Operation *readingOp = uRead->getOwner();
3867a1579acSMatthias Springer   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
3877a1579acSMatthias Springer 
3887a1579acSMatthias Springer   OpBuilder b(conflictingWritingOp->getContext());
3897a1579acSMatthias Springer   std::string id = "C_" + std::to_string(counter++);
3907a1579acSMatthias Springer 
3917a1579acSMatthias Springer   std::string conflictingWriteAttr =
3927a1579acSMatthias Springer       id +
3937a1579acSMatthias Springer       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
3947a1579acSMatthias Springer       "]";
3957a1579acSMatthias Springer   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
3967a1579acSMatthias Springer 
3977a1579acSMatthias Springer   std::string readAttr =
3987a1579acSMatthias Springer       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
3997a1579acSMatthias Springer   readingOp->setAttr(readAttr, b.getUnitAttr());
4007a1579acSMatthias Springer 
4017a1579acSMatthias Springer   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
4027a1579acSMatthias Springer     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
4037a1579acSMatthias Springer                                 std::to_string(opResult.getResultNumber()) +
4047a1579acSMatthias Springer                                 "]";
4057a1579acSMatthias Springer     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
4067a1579acSMatthias Springer   } else {
4077a1579acSMatthias Springer     auto bbArg = lastWrite.cast<BlockArgument>();
4087a1579acSMatthias Springer     std::string lastWriteAttr =
4097a1579acSMatthias Springer         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
4107a1579acSMatthias Springer     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
4117a1579acSMatthias Springer   }
4127a1579acSMatthias Springer }
4137a1579acSMatthias Springer 
4147a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under
4157a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that
4167a1579acSMatthias Springer /// all given writes bufferize inplace.
4177a1579acSMatthias Springer ///
4187a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read
4197a1579acSMatthias Springer /// the result of a write W1. But because of bufferization decisions, R actually
4207a1579acSMatthias Springer /// reads another write W2.
hasReadAfterWriteInterference(const DenseSet<OpOperand * > & usesRead,const DenseSet<OpOperand * > & usesWrite,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo)4217a1579acSMatthias Springer static bool hasReadAfterWriteInterference(
4227a1579acSMatthias Springer     const DenseSet<OpOperand *> &usesRead,
4237a1579acSMatthias Springer     const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
4249597b16aSMatthias Springer     AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
4257a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
4267a1579acSMatthias Springer 
42737a14735SMatthias Springer   // Gather all written aliases. Skip over aliases that are not actual writes.
4289235e597SMatthias Springer   SmallVector<Value> writtenAliases;
4299235e597SMatthias Springer   for (OpOperand *uWrite : usesWrite)
43037a14735SMatthias Springer     if (isMemoryWrite(uWrite->get(), state))
4319235e597SMatthias Springer       writtenAliases.push_back(uWrite->get());
4329235e597SMatthias Springer   // Find the inner-most enclosing repetitive region of each alias. If this is
4339235e597SMatthias Springer   // the same region for every alias, save it in `repetitiveRegionOfWrites`.
4349235e597SMatthias Springer   Optional<Region *> repetitiveRegionOfWrites =
4359235e597SMatthias Springer       getCommonEnclosingRepetitiveRegion(writtenAliases);
4369235e597SMatthias Springer 
4377a1579acSMatthias Springer   for (OpOperand *uRead : usesRead) {
4387a1579acSMatthias Springer     Operation *readingOp = uRead->getOwner();
4397a1579acSMatthias Springer 
4407a1579acSMatthias Springer     // Find most recent writes of uRead by following the SSA use-def chain.
4417a1579acSMatthias Springer     // E.g.:
4427a1579acSMatthias Springer     //
4437a1579acSMatthias Springer     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
4447a1579acSMatthias Springer     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
4457a1579acSMatthias Springer     // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
4467a1579acSMatthias Springer     //
4477a1579acSMatthias Springer     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
4487a1579acSMatthias Springer     // is %0. Note that operations that create an alias but do not write (such
4497a1579acSMatthias Springer     // as ExtractSliceOp) are skipped.
4507a1579acSMatthias Springer     SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
4517a1579acSMatthias Springer 
4527a1579acSMatthias Springer     // Look for conflicting memory writes. Potential conflicts are writes to an
4537a1579acSMatthias Springer     // alias that have been decided to bufferize inplace.
4547a1579acSMatthias Springer     for (OpOperand *uConflictingWrite : usesWrite) {
4557a1579acSMatthias Springer       // Throughout this loop, check for multiple requirements that have to be
4567a1579acSMatthias Springer       // met for uConflictingWrite to be an actual conflict.
4577a1579acSMatthias Springer       Operation *conflictingWritingOp = uConflictingWrite->getOwner();
4587a1579acSMatthias Springer 
4599235e597SMatthias Springer       // Check if conflictingWritingOp is in the same repetitive region as all
4609235e597SMatthias Springer       // written aliases. If this is not the case, there is no meaningful
4619235e597SMatthias Springer       // `happensBefore` relationship because conflictingWritingOp may be
4629235e597SMatthias Springer       // executed multiple times. E.g.:
4639235e597SMatthias Springer       //
4649235e597SMatthias Springer       // %0 = ... : tensor<?xf32>
4659235e597SMatthias Springer       // scf.for ... {
4669235e597SMatthias Springer       //   "reading_op"(%0) : tensor<?xf32>
4679235e597SMatthias Springer       //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
4689235e597SMatthias Springer       //   ...
4699235e597SMatthias Springer       // }
4709235e597SMatthias Springer       //
4719235e597SMatthias Springer       // In the above example, reading_op happens before writing_op according to
4729235e597SMatthias Springer       // op dominance. However, both ops may happen multiple times; in
4739235e597SMatthias Springer       // particular, the second execution of reading_op happens after the first
4749235e597SMatthias Springer       // execution of writing_op. This is problematic if the tensor they operate
4759235e597SMatthias Springer       // on (%0) is defined outside of the loop.
4769235e597SMatthias Springer       //
4779235e597SMatthias Springer       // Counter example:
4789235e597SMatthias Springer       //
4799235e597SMatthias Springer       // scf.for ... {
4809235e597SMatthias Springer       //   %0 = ... : tensor<?xf32>
4819235e597SMatthias Springer       //   "reading_op"(%0) : tensor<?xf32>
4829235e597SMatthias Springer       //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
4839235e597SMatthias Springer       //   ...
4849235e597SMatthias Springer       // }
4859235e597SMatthias Springer       //
4869235e597SMatthias Springer       // In this example, %0 is in the same repetitive region as
4879235e597SMatthias Springer       // conflictingWritingOp, so op dominance can be used to compute the
4889235e597SMatthias Springer       // `happensBefore` relationship.
4899235e597SMatthias Springer       //
4909235e597SMatthias Springer       // Note: iter_args of loops are not aliases of their respective block
4919235e597SMatthias Springer       // arguments, so op domanice can be used when analyzing ops that operate
4929235e597SMatthias Springer       // on them.
49337a14735SMatthias Springer       //
49437a14735SMatthias Springer       // Note: If `writtenAliases` is empty, there are no memory writes outside
49537a14735SMatthias Springer       // of the repetitive region of conflictingWritingOp, which means that all
49637a14735SMatthias Springer       // relevant aliases are inside the same repetitive region.
4979235e597SMatthias Springer       bool canUseOpDominance =
49837a14735SMatthias Springer           writtenAliases.empty() ||
4999235e597SMatthias Springer           repetitiveRegionOfWrites ==
5009235e597SMatthias Springer               getEnclosingRepetitiveRegion(conflictingWritingOp);
5019235e597SMatthias Springer 
5027a1579acSMatthias Springer       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
5037a1579acSMatthias Springer       // write is not visible when reading.
5049235e597SMatthias Springer       //
5059235e597SMatthias Springer       // Note: If ops are executed multiple times (e.g., because they are inside
5069235e597SMatthias Springer       //       a loop), there may be no meaningful `happensBefore` relationship.
5079235e597SMatthias Springer       if (canUseOpDominance &&
5089235e597SMatthias Springer           happensBefore(readingOp, conflictingWritingOp, domInfo))
5097a1579acSMatthias Springer         continue;
5107a1579acSMatthias Springer 
5117a1579acSMatthias Springer       // No conflict if the reading use equals the use of the conflicting write.
5129235e597SMatthias Springer       // A use cannot conflict with itself.
5139235e597SMatthias Springer       //
5149235e597SMatthias Springer       // Note: Just being the same op is not enough. It has to be the same use.
5159235e597SMatthias Springer       // Note: If the op is executed multiple times (e.g., because it is inside
5169235e597SMatthias Springer       //       a loop), it may be conflicting with itself.
5179235e597SMatthias Springer       if (canUseOpDominance && uConflictingWrite == uRead)
5187a1579acSMatthias Springer         continue;
5197a1579acSMatthias Springer 
5207a1579acSMatthias Springer       // No conflict if the op interface says so.
5217a1579acSMatthias Springer       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
5227a1579acSMatthias Springer         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
5237a1579acSMatthias Springer           continue;
5247a1579acSMatthias Springer 
5257a1579acSMatthias Springer       if (conflictingWritingOp != readingOp)
5267a1579acSMatthias Springer         if (auto bufferizableOp =
5277a1579acSMatthias Springer                 options.dynCastBufferizableOp(conflictingWritingOp))
5287a1579acSMatthias Springer           if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
5297a1579acSMatthias Springer             continue;
5307a1579acSMatthias Springer 
5317a1579acSMatthias Springer       // Ops are not conflicting if they are in mutually exclusive regions.
5329235e597SMatthias Springer       //
5339235e597SMatthias Springer       // Note: If ops are executed multiple times (e.g., because they are inside
5349235e597SMatthias Springer       //       a loop), mutually exclusive regions may be executed multiple
5359235e597SMatthias Springer       //       times.
5369235e597SMatthias Springer       if (canUseOpDominance &&
5379235e597SMatthias Springer           insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
5387a1579acSMatthias Springer         continue;
5397a1579acSMatthias Springer 
5407a1579acSMatthias Springer       // Check all possible last writes.
5417a1579acSMatthias Springer       for (Value lastWrite : lastWrites) {
5427a1579acSMatthias Springer         // No conflict if the conflicting write happens before the last
5437a1579acSMatthias Springer         // write.
5447a1579acSMatthias Springer         if (Operation *writingOp = lastWrite.getDefiningOp()) {
5457a1579acSMatthias Springer           if (happensBefore(conflictingWritingOp, writingOp, domInfo))
5467a1579acSMatthias Springer             // conflictingWritingOp happens before writingOp. No conflict.
5477a1579acSMatthias Springer             continue;
5487a1579acSMatthias Springer           // No conflict if conflictingWritingOp is contained in writingOp.
5497a1579acSMatthias Springer           if (writingOp->isProperAncestor(conflictingWritingOp))
5507a1579acSMatthias Springer             continue;
5517a1579acSMatthias Springer         } else {
5527a1579acSMatthias Springer           auto bbArg = lastWrite.cast<BlockArgument>();
5537a1579acSMatthias Springer           Block *block = bbArg.getOwner();
5547a1579acSMatthias Springer           if (!block->findAncestorOpInBlock(*conflictingWritingOp))
5557a1579acSMatthias Springer             // conflictingWritingOp happens outside of the block. No
5567a1579acSMatthias Springer             // conflict.
5577a1579acSMatthias Springer             continue;
5587a1579acSMatthias Springer         }
5597a1579acSMatthias Springer 
5607a1579acSMatthias Springer         // No conflict if the conflicting write and the last write are the same
5617a1579acSMatthias Springer         // use.
562585a8a32SMatthias Springer         SmallVector<OpResult> aliasingOpResult =
563585a8a32SMatthias Springer             state.getAliasingOpResult(*uConflictingWrite);
564585a8a32SMatthias Springer         if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
5657a1579acSMatthias Springer           continue;
5667a1579acSMatthias Springer 
5677a1579acSMatthias Springer         // All requirements are met. Conflict found!
5687a1579acSMatthias Springer 
5697a1579acSMatthias Springer         if (options.printConflicts)
5707a1579acSMatthias Springer           annotateConflict(uRead, uConflictingWrite, lastWrite);
5717a1579acSMatthias Springer 
5727a1579acSMatthias Springer         return true;
5737a1579acSMatthias Springer       }
5747a1579acSMatthias Springer     }
5757a1579acSMatthias Springer   }
5767a1579acSMatthias Springer 
5777a1579acSMatthias Springer   return false;
5787a1579acSMatthias Springer }
5797a1579acSMatthias Springer 
580032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the writes.
getAliasingInplaceWrites(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)581032be233SMatthias Springer static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
582032be233SMatthias Springer                                      const BufferizationAliasInfo &aliasInfo,
583032be233SMatthias Springer                                      const AnalysisState &state) {
584032be233SMatthias Springer   aliasInfo.applyOnAliases(root, [&](Value alias) {
585032be233SMatthias Springer     for (auto &use : alias.getUses())
586032be233SMatthias Springer       // Inplace write to a value that aliases root.
587032be233SMatthias Springer       if (isInplaceMemoryWrite(use, aliasInfo, state))
588032be233SMatthias Springer         res.insert(&use);
589032be233SMatthias Springer   });
590032be233SMatthias Springer }
591032be233SMatthias Springer 
592032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the reads.
getAliasingReads(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)593032be233SMatthias Springer static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
594032be233SMatthias Springer                              const BufferizationAliasInfo &aliasInfo,
595032be233SMatthias Springer                              const AnalysisState &state) {
596032be233SMatthias Springer   aliasInfo.applyOnAliases(root, [&](Value alias) {
597032be233SMatthias Springer     for (auto &use : alias.getUses())
598032be233SMatthias Springer       // Read to a value that aliases root.
599032be233SMatthias Springer       if (state.bufferizesToMemoryRead(use))
600032be233SMatthias Springer         res.insert(&use);
601032be233SMatthias Springer   });
602032be233SMatthias Springer }
603032be233SMatthias Springer 
6047a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read
6057a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization
6067a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that
6077a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains.
6087a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace
6097a1579acSMatthias Springer /// bufferization decision.
6107a1579acSMatthias Springer ///
6117a1579acSMatthias Springer /// Example:
6127a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
6137a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
6147a1579acSMatthias Springer /// %e = tensor.extract_slice %1
6157a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
6167a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
6177a1579acSMatthias Springer ///
6187a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to
6197a1579acSMatthias Springer /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
6207a1579acSMatthias Springer /// conflict because:
6217a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1.
6227a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second
6237a1579acSMatthias Springer ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
6247a1579acSMatthias Springer ///   would no longer be reading the result of %1.
6257a1579acSMatthias Springer ///
6267a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a
6277a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would
6287a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions.
6297a1579acSMatthias Springer ///
6307a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null
6317a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions
6327a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked.
wouldCreateReadAfterWriteInterference(OpOperand & operand,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo,bool checkConsistencyOnly=false)6337a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference(
6349597b16aSMatthias Springer     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
6357a1579acSMatthias Springer     const BufferizationAliasInfo &aliasInfo,
6367a1579acSMatthias Springer     bool checkConsistencyOnly = false) {
6377a1579acSMatthias Springer   // Collect reads and writes of all aliases of OpOperand and OpResult.
6387a1579acSMatthias Springer   DenseSet<OpOperand *> usesRead, usesWrite;
639032be233SMatthias Springer   getAliasingReads(usesRead, operand.get(), aliasInfo, state);
640032be233SMatthias Springer   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
641585a8a32SMatthias Springer   for (OpResult result : state.getAliasingOpResult(operand)) {
642032be233SMatthias Springer     getAliasingReads(usesRead, result, aliasInfo, state);
643032be233SMatthias Springer     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
6447a1579acSMatthias Springer   }
6457a1579acSMatthias Springer   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
6467a1579acSMatthias Springer     usesWrite.insert(&operand);
6477a1579acSMatthias Springer 
6487a1579acSMatthias Springer   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
6497a1579acSMatthias Springer                                        aliasInfo);
6507a1579acSMatthias Springer }
6517a1579acSMatthias Springer 
652032be233SMatthias Springer /// Check the reverse SSA use-def chain (following aliasing OpOperands) for
653032be233SMatthias Springer /// non-writable tensor values. Stop searching when an out-of-place bufferized
654032be233SMatthias Springer /// OpOperand was found (or when the OpOperand was not bufferized yet).
655032be233SMatthias Springer /// `currentOpOperand` is assumed to be in-place, even if that decision was not
656032be233SMatthias Springer /// materialized in `aliasInfo` yet.
6577a1579acSMatthias Springer static bool
hasPrecedingAliasingNonWritableTensor(Value value,OpOperand * currentOpOperand,const BufferizationAliasInfo & aliasInfo,const OneShotAnalysisState & state)658032be233SMatthias Springer hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
6597a1579acSMatthias Springer                                       const BufferizationAliasInfo &aliasInfo,
660032be233SMatthias Springer                                       const OneShotAnalysisState &state) {
661032be233SMatthias Springer   SmallVector<Value> worklist;
662032be233SMatthias Springer   worklist.push_back(value);
663032be233SMatthias Springer   while (!worklist.empty()) {
664032be233SMatthias Springer     Value nextVal = worklist.pop_back_val();
665032be233SMatthias Springer     if (!state.isWritable(nextVal))
666032be233SMatthias Springer       return true;
667032be233SMatthias Springer 
668032be233SMatthias Springer     // If `nextVal` is not a BlockArgument: End of use-def chain reached.
669032be233SMatthias Springer     auto opResult = nextVal.dyn_cast<OpResult>();
670032be233SMatthias Springer     if (!opResult)
671032be233SMatthias Springer       continue;
672032be233SMatthias Springer 
673032be233SMatthias Springer     // Follow reverse SSA use-def chain.
674032be233SMatthias Springer     SmallVector<OpOperand *> aliasingOpOperands =
675032be233SMatthias Springer         state.getAliasingOpOperand(opResult);
676032be233SMatthias Springer     for (OpOperand *opOperand : aliasingOpOperands)
677032be233SMatthias Springer       if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
678032be233SMatthias Springer         worklist.push_back(opOperand->get());
679032be233SMatthias Springer   }
6807a1579acSMatthias Springer   return false;
681032be233SMatthias Springer }
6827a1579acSMatthias Springer 
683032be233SMatthias Springer /// Return true if bufferizing `operand` inplace would create a write to a
684032be233SMatthias Springer /// non-writable buffer.
wouldCreateWriteToNonWritableBuffer(OpOperand & operand,const BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,bool checkConsistencyOnly=false)685032be233SMatthias Springer static bool wouldCreateWriteToNonWritableBuffer(
686032be233SMatthias Springer     OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
687032be233SMatthias Springer     OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
688032be233SMatthias Springer   // Collect writes of all aliases of OpOperand and OpResult.
689032be233SMatthias Springer   DenseSet<OpOperand *> usesWrite;
690032be233SMatthias Springer   getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
691032be233SMatthias Springer   for (OpResult result : state.getAliasingOpResult(operand)) {
692032be233SMatthias Springer     getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
693032be233SMatthias Springer   }
694032be233SMatthias Springer   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
695032be233SMatthias Springer     usesWrite.insert(&operand);
6967a1579acSMatthias Springer 
697032be233SMatthias Springer   // Assuming that `operand` bufferizes in-place: For each write (to each
698032be233SMatthias Springer   // alias), check if there is a non-writable tensor in the reverse SSA use-def
699032be233SMatthias Springer   // chain.
700032be233SMatthias Springer   for (OpOperand *uWrite : usesWrite)
701032be233SMatthias Springer     if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
702032be233SMatthias Springer                                               aliasInfo, state))
703032be233SMatthias Springer       return true;
7047a1579acSMatthias Springer 
705032be233SMatthias Springer   return false;
7067a1579acSMatthias Springer }
7077a1579acSMatthias Springer 
7087a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7097a1579acSMatthias Springer // Bufferization analyses.
7107a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7117a1579acSMatthias Springer 
7127a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place.
bufferizableInPlaceAnalysisImpl(OpOperand & operand,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo)7137a1579acSMatthias Springer static LogicalResult bufferizableInPlaceAnalysisImpl(
714032be233SMatthias Springer     OpOperand &operand, BufferizationAliasInfo &aliasInfo,
715032be233SMatthias Springer     OneShotAnalysisState &state, const DominanceInfo &domInfo) {
7167a1579acSMatthias Springer   bool foundInterference =
7177a1579acSMatthias Springer       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
7187a1579acSMatthias Springer       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
7197a1579acSMatthias Springer 
7207a1579acSMatthias Springer   if (foundInterference)
7217a1579acSMatthias Springer     aliasInfo.bufferizeOutOfPlace(operand);
7227a1579acSMatthias Springer   else
7237a1579acSMatthias Springer     aliasInfo.bufferizeInPlace(operand, state);
7247a1579acSMatthias Springer 
7257a1579acSMatthias Springer   return success();
7267a1579acSMatthias Springer }
7277a1579acSMatthias Springer 
7287a1579acSMatthias Springer /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
7297a1579acSMatthias Springer /// reverse and bufferize ops greedily. This is a good starter heuristic.
7307a1579acSMatthias Springer ///
7317a1579acSMatthias Springer /// Even if an op does not read or write, it may still create an alias when
7327a1579acSMatthias Springer /// bufferized in-place. An example of such ops is tensor.extract_slice.
7337a1579acSMatthias Springer ///
7347a1579acSMatthias Springer /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
7357a1579acSMatthias Springer ///
7367a1579acSMatthias Springer /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
7377a1579acSMatthias Springer /// cannot change the flow of information for either the source or the
7387a1579acSMatthias Springer /// result buffers.
7397a1579acSMatthias Springer ///
7407a1579acSMatthias Springer /// When bufferized inplace, an ExtractSliceOp does not by itself create any
7417a1579acSMatthias Springer /// read or write from memory. Instead, it has the effect of merging the alias
7427a1579acSMatthias Springer /// sets of the source and the result buffers.
7437a1579acSMatthias Springer ///
7447a1579acSMatthias Springer /// An analysis is required to ensure inplace bufferization would not result in
7457a1579acSMatthias Springer /// RaW dependence violations.
inPlaceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)7467a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
7477a1579acSMatthias Springer                                      BufferizationAliasInfo &aliasInfo,
748032be233SMatthias Springer                                      OneShotAnalysisState &state,
7497a1579acSMatthias Springer                                      const DominanceInfo &domInfo,
7507a1579acSMatthias Springer                                      unsigned analysisFuzzerSeed = 0) {
7517a1579acSMatthias Springer   if (analysisFuzzerSeed) {
7527a1579acSMatthias Springer     // This is a fuzzer. For testing purposes only. Randomize the order in which
7537a1579acSMatthias Springer     // operations are analyzed. The bufferization quality is likely worse, but
7547a1579acSMatthias Springer     // we want to make sure that no assertions are triggered anywhere.
7557a1579acSMatthias Springer     std::mt19937 g(analysisFuzzerSeed);
7567a1579acSMatthias Springer     llvm::shuffle(ops.begin(), ops.end(), g);
7577a1579acSMatthias Springer   }
7587a1579acSMatthias Springer 
7597a1579acSMatthias Springer   // Walk ops in reverse for better interference analysis.
7607a1579acSMatthias Springer   for (Operation *op : reverse(ops))
7617a1579acSMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
7627a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
7637a1579acSMatthias Springer         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
7647a1579acSMatthias Springer           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
7657a1579acSMatthias Springer                                                      state, domInfo)))
7667a1579acSMatthias Springer             return failure();
7677a1579acSMatthias Springer 
7687a1579acSMatthias Springer   return success();
7697a1579acSMatthias Springer }
7707a1579acSMatthias Springer 
7717a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand.
hasTensorSemantics(Operation * op)7727a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) {
7737a1579acSMatthias Springer   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
7747a1579acSMatthias Springer   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
7757a1579acSMatthias Springer   return hasTensorResult || hasTensorOperand;
7767a1579acSMatthias Springer }
7777a1579acSMatthias Springer 
7787a1579acSMatthias Springer /// Analyze all ops that are contained in `op`.
inPlaceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)7797a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(Operation *op,
7807a1579acSMatthias Springer                                      BufferizationAliasInfo &aliasInfo,
781032be233SMatthias Springer                                      OneShotAnalysisState &state,
7827a1579acSMatthias Springer                                      const DominanceInfo &domInfo,
7837a1579acSMatthias Springer                                      unsigned analysisFuzzerSeed = 0) {
7847a1579acSMatthias Springer   // Collect ops so we can build our own reverse traversal.
7857a1579acSMatthias Springer   SmallVector<Operation *> ops;
7867a1579acSMatthias Springer   op->walk([&](Operation *op) {
7877a1579acSMatthias Springer     // No tensors => no buffers.
7887a1579acSMatthias Springer     if (!hasTensorSemantics(op))
7897a1579acSMatthias Springer       return;
7907a1579acSMatthias Springer     ops.push_back(op);
7917a1579acSMatthias Springer   });
7927a1579acSMatthias Springer 
7937a1579acSMatthias Springer   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
7947a1579acSMatthias Springer }
7957a1579acSMatthias Springer 
7967a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
equivalenceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,AnalysisState & state)7977a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops,
7987a1579acSMatthias Springer                                 BufferizationAliasInfo &aliasInfo,
7999597b16aSMatthias Springer                                 AnalysisState &state) {
8007a1579acSMatthias Springer   for (Operation *op : ops)
8017a1579acSMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
8027a1579acSMatthias Springer       for (OpResult opResult : op->getOpResults())
8037a1579acSMatthias Springer         if (opResult.getType().isa<TensorType>())
8047a1579acSMatthias Springer           for (OpOperand *opOperand :
8057a1579acSMatthias Springer                bufferizableOp.getAliasingOpOperand(opResult, state))
8067a1579acSMatthias Springer             if (state.isInPlace(*opOperand))
8077a1579acSMatthias Springer               if (bufferizableOp.bufferRelation(opResult, state) ==
8087a1579acSMatthias Springer                   BufferRelation::Equivalent)
8097a1579acSMatthias Springer                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
8107a1579acSMatthias Springer }
8117a1579acSMatthias Springer 
8127a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
8137a1579acSMatthias Springer /// in `op`.
equivalenceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,AnalysisState & state)8147a1579acSMatthias Springer static void equivalenceAnalysis(Operation *op,
8157a1579acSMatthias Springer                                 BufferizationAliasInfo &aliasInfo,
8169597b16aSMatthias Springer                                 AnalysisState &state) {
8177a1579acSMatthias Springer   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
8187a1579acSMatthias Springer   SmallVector<Operation *> ops;
8197a1579acSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
8207a1579acSMatthias Springer     // No tensors => no buffers.
8217a1579acSMatthias Springer     if (none_of(op->getResultTypes(), isaTensor))
8227a1579acSMatthias Springer       return;
8237a1579acSMatthias Springer     ops.push_back(op);
8247a1579acSMatthias Springer   });
8257a1579acSMatthias Springer 
8267a1579acSMatthias Springer   equivalenceAnalysis(ops, aliasInfo, state);
8277a1579acSMatthias Springer }
8287a1579acSMatthias Springer 
8297a1579acSMatthias Springer /// Assert that the current bufferization decisions are consistent.
8307a1579acSMatthias Springer static LogicalResult
checkAliasInfoConsistency(Operation * op,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo)8317a1579acSMatthias Springer checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
8329597b16aSMatthias Springer                           AnalysisState &state,
8337a1579acSMatthias Springer                           const BufferizationAliasInfo &aliasInfo) {
8347a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
8357a1579acSMatthias Springer   Operation *inconsistentOp = nullptr;
8367a1579acSMatthias Springer   WalkResult walkResult = op->walk([&](Operation *op) {
8377a1579acSMatthias Springer     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
8387a1579acSMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
8397a1579acSMatthias Springer         if (opOperand.get().getType().isa<TensorType>()) {
8407a1579acSMatthias Springer           if (wouldCreateReadAfterWriteInterference(
8417a1579acSMatthias Springer                   opOperand, domInfo, state, aliasInfo,
8427a1579acSMatthias Springer                   /*checkConsistencyOnly=*/true)) {
8437a1579acSMatthias Springer             // This error can happen if certain "mustBufferizeInPlace" interface
8447a1579acSMatthias Springer             // methods are implemented incorrectly, such that the IR already has
8457a1579acSMatthias Springer             // a RaW conflict before making any bufferization decisions.
8467a1579acSMatthias Springer             inconsistentOp = op;
8477a1579acSMatthias Springer             return WalkResult::interrupt();
8487a1579acSMatthias Springer           }
8497a1579acSMatthias Springer         }
8507a1579acSMatthias Springer     return WalkResult::advance();
8517a1579acSMatthias Springer   });
8527a1579acSMatthias Springer 
8537a1579acSMatthias Springer   if (walkResult.wasInterrupted())
8547a1579acSMatthias Springer     return inconsistentOp->emitError("input IR has RaW conflict");
8557a1579acSMatthias Springer   return success();
8567a1579acSMatthias Springer }
8577a1579acSMatthias Springer 
8587a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only.
8597a1579acSMatthias Springer static void
annotateOpsWithBufferizationMarkers(Operation * op,const BufferizationAliasInfo & aliasInfo,AnalysisState & state)8607a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op,
8617a1579acSMatthias Springer                                     const BufferizationAliasInfo &aliasInfo,
8629597b16aSMatthias Springer                                     AnalysisState &state) {
8637a1579acSMatthias Springer   op->walk([&](Operation *op) {
8647a1579acSMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
8657a1579acSMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
8667a1579acSMatthias Springer         if (opOperand.get().getType().isa<TensorType>())
8677a1579acSMatthias Springer           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
8687a1579acSMatthias Springer   });
8697a1579acSMatthias Springer }
8707a1579acSMatthias Springer 
8717a1579acSMatthias Springer /// Assert that IR is in destination-passing style. I.e., every value that is
8727a1579acSMatthias Springer /// returned or yielded from a block is:
8737a1579acSMatthias Springer /// * aliasing a bbArg of that block or a parent block, or
8747a1579acSMatthias Springer /// * aliasing an OpResult of a op in a parent block.
8757a1579acSMatthias Springer ///
8767a1579acSMatthias Springer /// Example:
8777a1579acSMatthias Springer /// ```
8787a1579acSMatthias Springer /// %0 = "some_op" : tensor<?xf32>
8797a1579acSMatthias Springer /// %1 = scf.if %c -> (tensor<?xf32>) {
8807a1579acSMatthias Springer ///   scf.yield %0 : tensor<?xf32>
8817a1579acSMatthias Springer /// } else {
882ffdbecccSMatthias Springer ///   %t = linalg.alloc_tensor : tensor<?xf32>
8837a1579acSMatthias Springer ///   scf.yield %t : tensor<?xf32>
8847a1579acSMatthias Springer /// }
8857a1579acSMatthias Springer /// ```
8867a1579acSMatthias Springer /// In the above example, the first scf.yield op satifies destination-passing
8877a1579acSMatthias Springer /// style because the yielded value %0 is defined in the parent block. The
8887a1579acSMatthias Springer /// second scf.yield op does not satisfy destination-passing style because the
8897a1579acSMatthias Springer /// yielded value %t is defined in the same block as the scf.yield op.
8907a1579acSMatthias Springer // TODO: The current implementation checks for equivalent values instead of
8917a1579acSMatthias Springer // aliasing values, which is stricter than needed. We can currently not check
8927a1579acSMatthias Springer // for aliasing values because the analysis is a maybe-alias analysis and we
8937a1579acSMatthias Springer // need a must-alias analysis here.
894cdb7675cSMatthias Springer static LogicalResult
assertDestinationPassingStyle(Operation * op,AnalysisState & state,BufferizationAliasInfo & aliasInfo,SmallVector<Operation * > & newOps)8959597b16aSMatthias Springer assertDestinationPassingStyle(Operation *op, AnalysisState &state,
8967a1579acSMatthias Springer                               BufferizationAliasInfo &aliasInfo,
897cdb7675cSMatthias Springer                               SmallVector<Operation *> &newOps) {
8987a1579acSMatthias Springer   LogicalResult status = success();
8997a1579acSMatthias Springer   DominanceInfo domInfo(op);
9007a1579acSMatthias Springer   op->walk([&](Operation *returnOp) {
9013b426868SMatthias Springer     if (!isRegionReturnLike(returnOp) ||
9023b426868SMatthias Springer         !state.getOptions().isOpAllowed(returnOp))
9037a1579acSMatthias Springer       return WalkResult::advance();
9047a1579acSMatthias Springer 
9057a1579acSMatthias Springer     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
9067a1579acSMatthias Springer       Value returnVal = returnValOperand.get();
9077a1579acSMatthias Springer       // Skip non-tensor values.
9087a1579acSMatthias Springer       if (!returnVal.getType().isa<TensorType>())
9097a1579acSMatthias Springer         continue;
9107a1579acSMatthias Springer 
9117a1579acSMatthias Springer       bool foundEquivValue = false;
9127a1579acSMatthias Springer       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
9137a1579acSMatthias Springer         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
9147a1579acSMatthias Springer           Operation *definingOp = bbArg.getOwner()->getParentOp();
9157a1579acSMatthias Springer           if (definingOp->isProperAncestor(returnOp))
9167a1579acSMatthias Springer             foundEquivValue = true;
9177a1579acSMatthias Springer           return;
9187a1579acSMatthias Springer         }
9197a1579acSMatthias Springer 
9207a1579acSMatthias Springer         Operation *definingOp = equivVal.getDefiningOp();
9217a1579acSMatthias Springer         if (definingOp->getBlock()->findAncestorOpInBlock(
9227a1579acSMatthias Springer                 *returnOp->getParentOp()))
9237a1579acSMatthias Springer           // Skip ops that happen after `returnOp` and parent ops.
9247a1579acSMatthias Springer           if (happensBefore(definingOp, returnOp, domInfo))
9257a1579acSMatthias Springer             foundEquivValue = true;
9267a1579acSMatthias Springer       });
9277a1579acSMatthias Springer 
9287a1579acSMatthias Springer       if (!foundEquivValue)
9297a1579acSMatthias Springer         status =
9307a1579acSMatthias Springer             returnOp->emitError()
9317a1579acSMatthias Springer             << "operand #" << returnValOperand.getOperandNumber()
9327a1579acSMatthias Springer             << " of ReturnLike op does not satisfy destination passing style";
9337a1579acSMatthias Springer     }
9347a1579acSMatthias Springer 
9357a1579acSMatthias Springer     return WalkResult::advance();
9367a1579acSMatthias Springer   });
9377a1579acSMatthias Springer 
9387a1579acSMatthias Springer   return status;
9397a1579acSMatthias Springer }
9407a1579acSMatthias Springer 
analyzeOp(Operation * op,OneShotAnalysisState & state)9417a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op,
9429597b16aSMatthias Springer                                        OneShotAnalysisState &state) {
9437a1579acSMatthias Springer   DominanceInfo domInfo(op);
9447a1579acSMatthias Springer   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
9457a1579acSMatthias Springer   const auto &options =
9469597b16aSMatthias Springer       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
9477a1579acSMatthias Springer 
948d6dab38aSMatthias Springer   // Catch incorrect API usage.
949d6dab38aSMatthias Springer   assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
950d6dab38aSMatthias Springer           !options.bufferizeFunctionBoundaries) &&
951d6dab38aSMatthias Springer          "must use ModuleBufferize to bufferize function boundaries");
952d6dab38aSMatthias Springer 
9537a1579acSMatthias Springer   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
9547a1579acSMatthias Springer     return failure();
9557a1579acSMatthias Springer 
9567a1579acSMatthias Springer   // If the analysis fails, just return.
9577a1579acSMatthias Springer   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
9587a1579acSMatthias Springer                              options.analysisFuzzerSeed)))
9597a1579acSMatthias Springer     return failure();
9607a1579acSMatthias Springer   equivalenceAnalysis(op, aliasInfo, state);
9617a1579acSMatthias Springer 
962d1d79920SMatthias Springer   bool failedAnalysis = false;
963855a11eeSMatthias Springer   if (!options.allowReturnAllocs) {
9647a1579acSMatthias Springer     SmallVector<Operation *> newOps;
965d1d79920SMatthias Springer     failedAnalysis |=
966d1d79920SMatthias Springer         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
9677a1579acSMatthias Springer   }
9687a1579acSMatthias Springer 
969988748c0SMatthias Springer   // Gather some extra analysis data.
9709e24f0f4SMatthias Springer   state.gatherYieldedTensors(op);
971988748c0SMatthias Springer   state.gatherUndefinedTensorUses(op);
9729e24f0f4SMatthias Springer 
9734ec00fb3SMatthias Springer   // Analysis verification: After setting up alias/equivalence sets, each op
9744ec00fb3SMatthias Springer   // can check for expected invariants/limitations and fail the analysis if
9754ec00fb3SMatthias Springer   // necessary.
9764ec00fb3SMatthias Springer   op->walk([&](Operation *op) {
9774ec00fb3SMatthias Springer     if (BufferizableOpInterface bufferizableOp =
9784ec00fb3SMatthias Springer             options.dynCastBufferizableOp(op))
979d1d79920SMatthias Springer       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
9804ec00fb3SMatthias Springer   });
9814ec00fb3SMatthias Springer 
9827a1579acSMatthias Springer   // Annotate operations if we only want to report the analysis.
9837a1579acSMatthias Springer   if (options.testAnalysisOnly)
9847a1579acSMatthias Springer     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
9857a1579acSMatthias Springer 
986d1d79920SMatthias Springer   return success(!failedAnalysis);
9877a1579acSMatthias Springer }
9887a1579acSMatthias Springer 
9899597b16aSMatthias Springer LogicalResult
runOneShotBufferize(Operation * op,const OneShotBufferizationOptions & options)9909597b16aSMatthias Springer bufferization::runOneShotBufferize(Operation *op,
9919597b16aSMatthias Springer                                    const OneShotBufferizationOptions &options) {
9929597b16aSMatthias Springer   OneShotAnalysisState state(op, options);
993*b3ebe3beSMatthias Springer   if (failed(insertTensorCopies(op, options)))
9947a1579acSMatthias Springer     return failure();
995d2dacde5SMatthias Springer   if (options.testAnalysisOnly)
9967a1579acSMatthias Springer     return success();
997*b3ebe3beSMatthias Springer   return bufferizeOp(op, options, /*copyBeforeWrite=*/false);
9987a1579acSMatthias Springer }
999