17a1579acSMatthias Springer //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer //
97a1579acSMatthias Springer // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
107a1579acSMatthias Springer // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
117a1579acSMatthias Springer // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
127a1579acSMatthias Springer // call graphs.
137a1579acSMatthias Springer //
147a1579acSMatthias Springer // One-Shot Bufferize consists of two phases.
157a1579acSMatthias Springer //
167a1579acSMatthias Springer // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
177a1579acSMatthias Springer // inserting buffer copies. The analysis queries op bufferization semantics
187a1579acSMatthias Springer // via `BufferizableOpInterface`.
197a1579acSMatthias Springer // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
207a1579acSMatthias Springer // function does not generate buffer copies for OpResults that were decided
217a1579acSMatthias Springer // to bufferize inplace during the analysis phase.
227a1579acSMatthias Springer //
237a1579acSMatthias Springer // This file contains only the analysis. The actual bufferization is implemented
247a1579acSMatthias Springer // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
257a1579acSMatthias Springer // helper function `runOneShotBufferize` that analyzes an op (and its nested
267a1579acSMatthias Springer // ops) and then bufferizes it.
277a1579acSMatthias Springer //
287a1579acSMatthias Springer // Inplace bufferization decisions are passed from the analysis to the
299597b16aSMatthias Springer // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
307a1579acSMatthias Springer // They can be printed for debugging purposes with `testAnalysisOnly`.
317a1579acSMatthias Springer //
327a1579acSMatthias Springer // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
337a1579acSMatthias Springer // treated conservatively. E.g., the analysis has to assume that their tensor
347a1579acSMatthias Springer // OpOperands bufferize to memory writes. While such ops can be analyzed, they
357a1579acSMatthias Springer // are not bufferized and remain in the IR. to_tensor and to_memref ops are
367a1579acSMatthias Springer // inserted at the bufferization boundary.
377a1579acSMatthias Springer //
387a1579acSMatthias Springer // This analysis caters to high-performance codegen where buffer reuse is deemed
397a1579acSMatthias Springer // critical: the analysis should fail if the bufferized form of the function
40855a11eeSMatthias Springer // needs to return a buffer, unless `allowReturnAllocs` is enabled.
417a1579acSMatthias Springer
427a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
437a1579acSMatthias Springer
447a1579acSMatthias Springer #include <random>
457a1579acSMatthias Springer
467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
49*b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
50d6dab38aSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
517a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
527a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
537a1579acSMatthias Springer #include "mlir/IR/Dominance.h"
547a1579acSMatthias Springer #include "mlir/IR/Operation.h"
557a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
567a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
577a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h"
587a1579acSMatthias Springer #include "llvm/ADT/SetVector.h"
597a1579acSMatthias Springer
607a1579acSMatthias Springer using namespace mlir;
617a1579acSMatthias Springer using namespace mlir::bufferization;
627a1579acSMatthias Springer
isaTensor(Type t)637a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); }
647a1579acSMatthias Springer
657a1579acSMatthias Springer //===----------------------------------------------------------------------===//
667a1579acSMatthias Springer // Bufferization-specific attribute manipulation.
677a1579acSMatthias Springer // These are for testing and debugging only. Bufferization information is
687a1579acSMatthias Springer // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
697a1579acSMatthias Springer // is annotated with the results of the analysis (copied from
707a1579acSMatthias Springer // BufferizationAliasInfo), so that they can be checked in tests.
717a1579acSMatthias Springer //===----------------------------------------------------------------------===//
727a1579acSMatthias Springer
737a1579acSMatthias Springer /// Attribute marker to specify op results that can be bufferized inPlace.
747a1579acSMatthias Springer constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
757a1579acSMatthias Springer
767a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace.
setInPlaceOpOperand(OpOperand & opOperand,bool inPlace)777a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
787a1579acSMatthias Springer Operation *op = opOperand.getOwner();
797a1579acSMatthias Springer auto attr =
807a1579acSMatthias Springer op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
817a1579acSMatthias Springer SmallVector<StringRef> inPlaceVector;
827a1579acSMatthias Springer if (attr) {
837a1579acSMatthias Springer inPlaceVector = SmallVector<StringRef>(
847a1579acSMatthias Springer llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
857a1579acSMatthias Springer } else {
867a1579acSMatthias Springer inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
877a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands())
887a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>())
897a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = "false";
907a1579acSMatthias Springer }
917a1579acSMatthias Springer
927a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
937a1579acSMatthias Springer op->setAttr(kInPlaceResultsAttrName,
947a1579acSMatthias Springer OpBuilder(op).getStrArrayAttr(inPlaceVector));
957a1579acSMatthias Springer }
967a1579acSMatthias Springer
977a1579acSMatthias Springer //===----------------------------------------------------------------------===//
987a1579acSMatthias Springer // BufferizationAliasInfo
997a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1007a1579acSMatthias Springer
BufferizationAliasInfo(Operation * rootOp)1017a1579acSMatthias Springer BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
1027a1579acSMatthias Springer rootOp->walk([&](Operation *op) {
1037a1579acSMatthias Springer for (Value v : op->getResults())
1047a1579acSMatthias Springer if (v.getType().isa<TensorType>())
1057a1579acSMatthias Springer createAliasInfoEntry(v);
1067a1579acSMatthias Springer for (Region &r : op->getRegions())
1077a1579acSMatthias Springer for (Block &b : r.getBlocks())
1087a1579acSMatthias Springer for (auto bbArg : b.getArguments())
1097a1579acSMatthias Springer if (bbArg.getType().isa<TensorType>())
1107a1579acSMatthias Springer createAliasInfoEntry(bbArg);
1117a1579acSMatthias Springer });
1127a1579acSMatthias Springer }
1137a1579acSMatthias Springer
1147a1579acSMatthias Springer /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
1157a1579acSMatthias Springer /// beginning the alias and equivalence sets only contain `v` itself.
createAliasInfoEntry(Value v)1167a1579acSMatthias Springer void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
1177a1579acSMatthias Springer aliasInfo.insert(v);
1187a1579acSMatthias Springer equivalentInfo.insert(v);
1197a1579acSMatthias Springer }
1207a1579acSMatthias Springer
1217a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1227a1579acSMatthias Springer /// `alias`.
insertNewBufferAlias(Value newValue,Value alias)1237a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
1247a1579acSMatthias Springer createAliasInfoEntry(newValue);
1257a1579acSMatthias Springer aliasInfo.unionSets(newValue, alias);
1267a1579acSMatthias Springer }
1277a1579acSMatthias Springer
1287a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1297a1579acSMatthias Springer /// `alias`. Additionally, merge their equivalence classes.
insertNewBufferEquivalence(Value newValue,Value alias)1307a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
1317a1579acSMatthias Springer Value alias) {
1327a1579acSMatthias Springer insertNewBufferAlias(newValue, alias);
1337a1579acSMatthias Springer equivalentInfo.unionSets(newValue, alias);
1347a1579acSMatthias Springer }
1357a1579acSMatthias Springer
1367a1579acSMatthias Springer /// Return `true` if a value was marked as in-place bufferized.
isInPlace(OpOperand & operand) const1377a1579acSMatthias Springer bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
1387a1579acSMatthias Springer return inplaceBufferized.contains(&operand);
1397a1579acSMatthias Springer }
1407a1579acSMatthias Springer
1417a1579acSMatthias Springer /// Set the inPlace bufferization spec to true.
bufferizeInPlace(OpOperand & operand,AnalysisState & state)1427a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
1439597b16aSMatthias Springer AnalysisState &state) {
1447a1579acSMatthias Springer markInPlace(operand);
145585a8a32SMatthias Springer for (OpResult result : state.getAliasingOpResult(operand))
1467a1579acSMatthias Springer aliasInfo.unionSets(result, operand.get());
1477a1579acSMatthias Springer }
1487a1579acSMatthias Springer
1497a1579acSMatthias Springer /// Set the inPlace bufferization spec to false.
bufferizeOutOfPlace(OpOperand & operand)1507a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
1517a1579acSMatthias Springer assert(!inplaceBufferized.contains(&operand) &&
1527a1579acSMatthias Springer "OpOperand was already decided to bufferize inplace");
1537a1579acSMatthias Springer }
1547a1579acSMatthias Springer
1557a1579acSMatthias Springer /// Apply `fun` to all the members of the equivalence class of `v`.
applyOnEquivalenceClass(Value v,function_ref<void (Value)> fun) const1567a1579acSMatthias Springer void BufferizationAliasInfo::applyOnEquivalenceClass(
1577a1579acSMatthias Springer Value v, function_ref<void(Value)> fun) const {
1587a1579acSMatthias Springer auto leaderIt = equivalentInfo.findLeader(v);
1597a1579acSMatthias Springer for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
1607a1579acSMatthias Springer ++mit) {
1617a1579acSMatthias Springer fun(*mit);
1627a1579acSMatthias Springer }
1637a1579acSMatthias Springer }
1647a1579acSMatthias Springer
1657a1579acSMatthias Springer /// Apply `fun` to all aliases of `v`.
applyOnAliases(Value v,function_ref<void (Value)> fun) const1667a1579acSMatthias Springer void BufferizationAliasInfo::applyOnAliases(
1677a1579acSMatthias Springer Value v, function_ref<void(Value)> fun) const {
1687a1579acSMatthias Springer auto leaderIt = aliasInfo.findLeader(v);
1697a1579acSMatthias Springer for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
1707a1579acSMatthias Springer fun(*mit);
1717a1579acSMatthias Springer }
1727a1579acSMatthias Springer }
1737a1579acSMatthias Springer
1747a1579acSMatthias Springer BufferizationAliasInfo::EquivalenceClassRangeType
getAliases(Value v) const1757a1579acSMatthias Springer BufferizationAliasInfo::getAliases(Value v) const {
1767a1579acSMatthias Springer DenseSet<Value> res;
1777a1579acSMatthias Springer auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
1787a1579acSMatthias Springer for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
1797a1579acSMatthias Springer mit != meit; ++mit) {
1807a1579acSMatthias Springer res.insert(static_cast<Value>(*mit));
1817a1579acSMatthias Springer }
1827a1579acSMatthias Springer return BufferizationAliasInfo::EquivalenceClassRangeType(
1837a1579acSMatthias Springer aliasInfo.member_begin(it), aliasInfo.member_end());
1847a1579acSMatthias Springer }
1857a1579acSMatthias Springer
1867a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1879597b16aSMatthias Springer // OneShotAnalysisState
1887a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1897a1579acSMatthias Springer
OneShotAnalysisState(Operation * op,const OneShotBufferizationOptions & options)1909597b16aSMatthias Springer OneShotAnalysisState::OneShotAnalysisState(
1919597b16aSMatthias Springer Operation *op, const OneShotBufferizationOptions &options)
1929597b16aSMatthias Springer : AnalysisState(options), aliasInfo(op) {
1937a1579acSMatthias Springer // Set up alias sets for OpResults that must bufferize in-place. This should
1947a1579acSMatthias Springer // be done before making any other bufferization decisions.
1957a1579acSMatthias Springer op->walk([&](BufferizableOpInterface bufferizableOp) {
1967a1579acSMatthias Springer if (!options.isOpAllowed(bufferizableOp))
1977a1579acSMatthias Springer return WalkResult::skip();
1987a1579acSMatthias Springer for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
1997a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>())
2007a1579acSMatthias Springer if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
201585a8a32SMatthias Springer for (OpResult opResult :
2027a1579acSMatthias Springer bufferizableOp.getAliasingOpResult(opOperand, *this))
2037a1579acSMatthias Springer aliasInfo.unionAliasSets(opOperand.get(), opResult);
2047a1579acSMatthias Springer aliasInfo.markInPlace(opOperand);
2057a1579acSMatthias Springer }
2067a1579acSMatthias Springer }
2077a1579acSMatthias Springer return WalkResult::advance();
2087a1579acSMatthias Springer });
2097a1579acSMatthias Springer }
2107a1579acSMatthias Springer
isInPlace(OpOperand & opOperand) const2119597b16aSMatthias Springer bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
2127a1579acSMatthias Springer return aliasInfo.isInPlace(opOperand);
2137a1579acSMatthias Springer }
2147a1579acSMatthias Springer
areEquivalentBufferizedValues(Value v1,Value v2) const2159597b16aSMatthias Springer bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
2167a1579acSMatthias Springer Value v2) const {
2177a1579acSMatthias Springer return aliasInfo.areEquivalentBufferizedValues(v1, v2);
2187a1579acSMatthias Springer }
2197a1579acSMatthias Springer
areAliasingBufferizedValues(Value v1,Value v2) const2203490aadfSMatthias Springer bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
2213490aadfSMatthias Springer Value v2) const {
2223490aadfSMatthias Springer return aliasInfo.areAliasingBufferizedValues(v1, v2);
2233490aadfSMatthias Springer }
2243490aadfSMatthias Springer
2259e24f0f4SMatthias Springer // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
2269e24f0f4SMatthias Springer // to ensure that such information is available during bufferization time.
2279e24f0f4SMatthias Springer // Alias information can no longer be queried through BufferizationAliasInfo
2289e24f0f4SMatthias Springer // once we have started modifying the IR.
gatherYieldedTensors(Operation * op)2299e24f0f4SMatthias Springer void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
2309e24f0f4SMatthias Springer op->walk([&](Operation *returnOp) {
2319e24f0f4SMatthias Springer if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
2329e24f0f4SMatthias Springer return WalkResult::advance();
2339e24f0f4SMatthias Springer
2349e24f0f4SMatthias Springer for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
2359e24f0f4SMatthias Springer Value returnVal = returnValOperand.get();
2369e24f0f4SMatthias Springer // Skip non-tensor values.
2379e24f0f4SMatthias Springer if (!returnVal.getType().isa<TensorType>())
2389e24f0f4SMatthias Springer continue;
2399e24f0f4SMatthias Springer
2409e24f0f4SMatthias Springer // Add all aliases of the returned value. But only the ones that are in
2419e24f0f4SMatthias Springer // the same block.
2429e24f0f4SMatthias Springer aliasInfo.applyOnAliases(returnVal, [&](Value v) {
2439e24f0f4SMatthias Springer if (auto bbArg = v.dyn_cast<BlockArgument>()) {
2449e24f0f4SMatthias Springer if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
2459e24f0f4SMatthias Springer yieldedTensors.insert(bbArg);
2469e24f0f4SMatthias Springer return;
2479e24f0f4SMatthias Springer }
2489e24f0f4SMatthias Springer Operation *definingOp = v.getDefiningOp();
2499e24f0f4SMatthias Springer if (definingOp->getParentOp() == returnOp->getParentOp())
2509e24f0f4SMatthias Springer yieldedTensors.insert(v);
2519e24f0f4SMatthias Springer });
2529e24f0f4SMatthias Springer }
2539e24f0f4SMatthias Springer
2549e24f0f4SMatthias Springer return WalkResult::advance();
2559e24f0f4SMatthias Springer });
2569e24f0f4SMatthias Springer }
2579e24f0f4SMatthias Springer
gatherUndefinedTensorUses(Operation * op)258988748c0SMatthias Springer void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
259988748c0SMatthias Springer op->walk([&](Operation *op) {
260988748c0SMatthias Springer // Skip unknown ops.
261988748c0SMatthias Springer auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
262988748c0SMatthias Springer if (!bufferizableOp)
263988748c0SMatthias Springer return WalkResult::skip();
264988748c0SMatthias Springer
265988748c0SMatthias Springer // Check all tensor OpResults.
266988748c0SMatthias Springer for (OpResult opResult : op->getOpResults()) {
267988748c0SMatthias Springer if (!opResult.getType().isa<TensorType>())
268988748c0SMatthias Springer continue;
269988748c0SMatthias Springer
270988748c0SMatthias Springer // If there is no preceding memory write, the tensor contents are
271988748c0SMatthias Springer // undefined.
272988748c0SMatthias Springer // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
273988748c0SMatthias Springer // use-def chain, it returns that value, regardless of whether it is a
274988748c0SMatthias Springer // memory write or not.
275988748c0SMatthias Springer SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
276988748c0SMatthias Springer bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
277988748c0SMatthias Springer if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
278988748c0SMatthias Springer return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
279988748c0SMatthias Springer *this);
280988748c0SMatthias Springer return true;
281988748c0SMatthias Springer });
282988748c0SMatthias Springer if (isUndefined)
283988748c0SMatthias Springer for (OpOperand &use : opResult.getUses())
284988748c0SMatthias Springer undefinedTensorUses.insert(&use);
285988748c0SMatthias Springer }
286988748c0SMatthias Springer
287988748c0SMatthias Springer return WalkResult::advance();
288988748c0SMatthias Springer });
289988748c0SMatthias Springer }
290988748c0SMatthias Springer
hasUndefinedContents(OpOperand * opOperand) const291988748c0SMatthias Springer bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
292988748c0SMatthias Springer return undefinedTensorUses.contains(opOperand);
293988748c0SMatthias Springer }
294988748c0SMatthias Springer
isTensorYielded(Value tensor) const2959e24f0f4SMatthias Springer bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
2969e24f0f4SMatthias Springer return yieldedTensors.contains(tensor);
2979e24f0f4SMatthias Springer }
2989e24f0f4SMatthias Springer
isValueWritten(Value value) const2993490aadfSMatthias Springer bool OneShotAnalysisState::isValueWritten(Value value) const {
3003490aadfSMatthias Springer bool isWritten = false;
3013490aadfSMatthias Springer aliasInfo.applyOnAliases(value, [&](Value val) {
3023490aadfSMatthias Springer for (OpOperand &use : val.getUses())
3033490aadfSMatthias Springer if (isInPlace(use) && bufferizesToMemoryWrite(use))
3043490aadfSMatthias Springer isWritten = true;
3053490aadfSMatthias Springer });
3063490aadfSMatthias Springer return isWritten;
3073490aadfSMatthias Springer }
3083490aadfSMatthias Springer
isWritable(Value value) const309032be233SMatthias Springer bool OneShotAnalysisState::isWritable(Value value) const {
310032be233SMatthias Springer // TODO: Out-of-place bufferized value could be considered writable.
311032be233SMatthias Springer if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
312032be233SMatthias Springer return bufferizableOp.isWritable(value, *this);
313032be233SMatthias Springer
314032be233SMatthias Springer // Query BufferizableOpInterface to see if the BlockArgument is writable.
315032be233SMatthias Springer if (auto bbArg = value.dyn_cast<BlockArgument>())
316032be233SMatthias Springer if (auto bufferizableOp =
317032be233SMatthias Springer getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
318032be233SMatthias Springer return bufferizableOp.isWritable(bbArg, *this);
319032be233SMatthias Springer
320032be233SMatthias Springer // Not a bufferizable op: The conservative answer is "not writable".
321032be233SMatthias Springer return false;
322032be233SMatthias Springer }
323032be233SMatthias Springer
3247a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3257a1579acSMatthias Springer // Bufferization-specific alias analysis.
3267a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3277a1579acSMatthias Springer
3287a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place.
isInplaceMemoryWrite(OpOperand & opOperand,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)3297a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand,
3307a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo,
331032be233SMatthias Springer const AnalysisState &state) {
3327a1579acSMatthias Springer // OpOperands that do not bufferize to a memory write do not write in-place.
3337a1579acSMatthias Springer if (!state.bufferizesToMemoryWrite(opOperand))
3347a1579acSMatthias Springer return false;
3357a1579acSMatthias Springer // Check current bufferization decisions.
3367a1579acSMatthias Springer return aliasInfo.isInPlace(opOperand);
3377a1579acSMatthias Springer }
3387a1579acSMatthias Springer
3397a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
3407a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`.
happensBefore(Operation * a,Operation * b,const DominanceInfo & domInfo)3417a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b,
3427a1579acSMatthias Springer const DominanceInfo &domInfo) {
3437a1579acSMatthias Springer do {
3447a1579acSMatthias Springer // TODO: Instead of isProperAncestor + properlyDominates, we should use
3457a1579acSMatthias Springer // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
3467a1579acSMatthias Springer if (a->isProperAncestor(b))
3477a1579acSMatthias Springer return false;
3487a1579acSMatthias Springer if (domInfo.properlyDominates(a, b))
3497a1579acSMatthias Springer return true;
3507a1579acSMatthias Springer } while ((a = a->getParentOp()));
3517a1579acSMatthias Springer return false;
3527a1579acSMatthias Springer }
3537a1579acSMatthias Springer
3549235e597SMatthias Springer /// For each given value, find the closest enclosing repetitive region. If this
3559235e597SMatthias Springer /// is the same region for each value, return it. Otherwise return None.
3569235e597SMatthias Springer /// Note: If there is no enclosing repetitive region, return nullptr.
3579235e597SMatthias Springer static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values)3589235e597SMatthias Springer getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
3599235e597SMatthias Springer if (values.empty())
3609235e597SMatthias Springer return None;
3619235e597SMatthias Springer Region *r = getEnclosingRepetitiveRegion(values.front());
3629235e597SMatthias Springer for (Value value : values.drop_front())
3639235e597SMatthias Springer if (getEnclosingRepetitiveRegion(value) != r)
3649235e597SMatthias Springer return None;
3659235e597SMatthias Springer return r;
3669235e597SMatthias Springer }
3679235e597SMatthias Springer
36837a14735SMatthias Springer /// Return `true` if the given tensor value is a memory write. Most values are
36937a14735SMatthias Springer /// tensor writes, but ops that define a tensor SSA value without specifying its
370ffdbecccSMatthias Springer /// contents (e.g., alloc_tensor) are not.
isMemoryWrite(Value value,const AnalysisState & state)37137a14735SMatthias Springer static bool isMemoryWrite(Value value, const AnalysisState &state) {
37237a14735SMatthias Springer auto opResult = value.dyn_cast<OpResult>();
37337a14735SMatthias Springer if (!opResult)
37437a14735SMatthias Springer return true;
37537a14735SMatthias Springer auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
37637a14735SMatthias Springer if (!bufferizableOp)
37737a14735SMatthias Springer return true;
37837a14735SMatthias Springer return bufferizableOp.isMemoryWrite(opResult, state);
37937a14735SMatthias Springer }
38037a14735SMatthias Springer
3817a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict.
annotateConflict(OpOperand * uRead,OpOperand * uConflictingWrite,Value lastWrite)3827a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
3837a1579acSMatthias Springer Value lastWrite) {
3847a1579acSMatthias Springer static uint64_t counter = 0;
3857a1579acSMatthias Springer Operation *readingOp = uRead->getOwner();
3867a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner();
3877a1579acSMatthias Springer
3887a1579acSMatthias Springer OpBuilder b(conflictingWritingOp->getContext());
3897a1579acSMatthias Springer std::string id = "C_" + std::to_string(counter++);
3907a1579acSMatthias Springer
3917a1579acSMatthias Springer std::string conflictingWriteAttr =
3927a1579acSMatthias Springer id +
3937a1579acSMatthias Springer "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
3947a1579acSMatthias Springer "]";
3957a1579acSMatthias Springer conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
3967a1579acSMatthias Springer
3977a1579acSMatthias Springer std::string readAttr =
3987a1579acSMatthias Springer id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
3997a1579acSMatthias Springer readingOp->setAttr(readAttr, b.getUnitAttr());
4007a1579acSMatthias Springer
4017a1579acSMatthias Springer if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
4027a1579acSMatthias Springer std::string lastWriteAttr = id + "[LAST-WRITE: result " +
4037a1579acSMatthias Springer std::to_string(opResult.getResultNumber()) +
4047a1579acSMatthias Springer "]";
4057a1579acSMatthias Springer opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
4067a1579acSMatthias Springer } else {
4077a1579acSMatthias Springer auto bbArg = lastWrite.cast<BlockArgument>();
4087a1579acSMatthias Springer std::string lastWriteAttr =
4097a1579acSMatthias Springer id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
4107a1579acSMatthias Springer bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
4117a1579acSMatthias Springer }
4127a1579acSMatthias Springer }
4137a1579acSMatthias Springer
4147a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under
4157a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that
4167a1579acSMatthias Springer /// all given writes bufferize inplace.
4177a1579acSMatthias Springer ///
4187a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read
4197a1579acSMatthias Springer /// the result of a write W1. But because of bufferization decisions, R actually
4207a1579acSMatthias Springer /// reads another write W2.
hasReadAfterWriteInterference(const DenseSet<OpOperand * > & usesRead,const DenseSet<OpOperand * > & usesWrite,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo)4217a1579acSMatthias Springer static bool hasReadAfterWriteInterference(
4227a1579acSMatthias Springer const DenseSet<OpOperand *> &usesRead,
4237a1579acSMatthias Springer const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
4249597b16aSMatthias Springer AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
4257a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions();
4267a1579acSMatthias Springer
42737a14735SMatthias Springer // Gather all written aliases. Skip over aliases that are not actual writes.
4289235e597SMatthias Springer SmallVector<Value> writtenAliases;
4299235e597SMatthias Springer for (OpOperand *uWrite : usesWrite)
43037a14735SMatthias Springer if (isMemoryWrite(uWrite->get(), state))
4319235e597SMatthias Springer writtenAliases.push_back(uWrite->get());
4329235e597SMatthias Springer // Find the inner-most enclosing repetitive region of each alias. If this is
4339235e597SMatthias Springer // the same region for every alias, save it in `repetitiveRegionOfWrites`.
4349235e597SMatthias Springer Optional<Region *> repetitiveRegionOfWrites =
4359235e597SMatthias Springer getCommonEnclosingRepetitiveRegion(writtenAliases);
4369235e597SMatthias Springer
4377a1579acSMatthias Springer for (OpOperand *uRead : usesRead) {
4387a1579acSMatthias Springer Operation *readingOp = uRead->getOwner();
4397a1579acSMatthias Springer
4407a1579acSMatthias Springer // Find most recent writes of uRead by following the SSA use-def chain.
4417a1579acSMatthias Springer // E.g.:
4427a1579acSMatthias Springer //
4437a1579acSMatthias Springer // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
4447a1579acSMatthias Springer // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
4457a1579acSMatthias Springer // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
4467a1579acSMatthias Springer //
4477a1579acSMatthias Springer // In the above example, if uRead is the OpOperand of reading_op, lastWrite
4487a1579acSMatthias Springer // is %0. Note that operations that create an alias but do not write (such
4497a1579acSMatthias Springer // as ExtractSliceOp) are skipped.
4507a1579acSMatthias Springer SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
4517a1579acSMatthias Springer
4527a1579acSMatthias Springer // Look for conflicting memory writes. Potential conflicts are writes to an
4537a1579acSMatthias Springer // alias that have been decided to bufferize inplace.
4547a1579acSMatthias Springer for (OpOperand *uConflictingWrite : usesWrite) {
4557a1579acSMatthias Springer // Throughout this loop, check for multiple requirements that have to be
4567a1579acSMatthias Springer // met for uConflictingWrite to be an actual conflict.
4577a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner();
4587a1579acSMatthias Springer
4599235e597SMatthias Springer // Check if conflictingWritingOp is in the same repetitive region as all
4609235e597SMatthias Springer // written aliases. If this is not the case, there is no meaningful
4619235e597SMatthias Springer // `happensBefore` relationship because conflictingWritingOp may be
4629235e597SMatthias Springer // executed multiple times. E.g.:
4639235e597SMatthias Springer //
4649235e597SMatthias Springer // %0 = ... : tensor<?xf32>
4659235e597SMatthias Springer // scf.for ... {
4669235e597SMatthias Springer // "reading_op"(%0) : tensor<?xf32>
4679235e597SMatthias Springer // %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
4689235e597SMatthias Springer // ...
4699235e597SMatthias Springer // }
4709235e597SMatthias Springer //
4719235e597SMatthias Springer // In the above example, reading_op happens before writing_op according to
4729235e597SMatthias Springer // op dominance. However, both ops may happen multiple times; in
4739235e597SMatthias Springer // particular, the second execution of reading_op happens after the first
4749235e597SMatthias Springer // execution of writing_op. This is problematic if the tensor they operate
4759235e597SMatthias Springer // on (%0) is defined outside of the loop.
4769235e597SMatthias Springer //
4779235e597SMatthias Springer // Counter example:
4789235e597SMatthias Springer //
4799235e597SMatthias Springer // scf.for ... {
4809235e597SMatthias Springer // %0 = ... : tensor<?xf32>
4819235e597SMatthias Springer // "reading_op"(%0) : tensor<?xf32>
4829235e597SMatthias Springer // %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
4839235e597SMatthias Springer // ...
4849235e597SMatthias Springer // }
4859235e597SMatthias Springer //
4869235e597SMatthias Springer // In this example, %0 is in the same repetitive region as
4879235e597SMatthias Springer // conflictingWritingOp, so op dominance can be used to compute the
4889235e597SMatthias Springer // `happensBefore` relationship.
4899235e597SMatthias Springer //
4909235e597SMatthias Springer // Note: iter_args of loops are not aliases of their respective block
4919235e597SMatthias Springer // arguments, so op domanice can be used when analyzing ops that operate
4929235e597SMatthias Springer // on them.
49337a14735SMatthias Springer //
49437a14735SMatthias Springer // Note: If `writtenAliases` is empty, there are no memory writes outside
49537a14735SMatthias Springer // of the repetitive region of conflictingWritingOp, which means that all
49637a14735SMatthias Springer // relevant aliases are inside the same repetitive region.
4979235e597SMatthias Springer bool canUseOpDominance =
49837a14735SMatthias Springer writtenAliases.empty() ||
4999235e597SMatthias Springer repetitiveRegionOfWrites ==
5009235e597SMatthias Springer getEnclosingRepetitiveRegion(conflictingWritingOp);
5019235e597SMatthias Springer
5027a1579acSMatthias Springer // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
5037a1579acSMatthias Springer // write is not visible when reading.
5049235e597SMatthias Springer //
5059235e597SMatthias Springer // Note: If ops are executed multiple times (e.g., because they are inside
5069235e597SMatthias Springer // a loop), there may be no meaningful `happensBefore` relationship.
5079235e597SMatthias Springer if (canUseOpDominance &&
5089235e597SMatthias Springer happensBefore(readingOp, conflictingWritingOp, domInfo))
5097a1579acSMatthias Springer continue;
5107a1579acSMatthias Springer
5117a1579acSMatthias Springer // No conflict if the reading use equals the use of the conflicting write.
5129235e597SMatthias Springer // A use cannot conflict with itself.
5139235e597SMatthias Springer //
5149235e597SMatthias Springer // Note: Just being the same op is not enough. It has to be the same use.
5159235e597SMatthias Springer // Note: If the op is executed multiple times (e.g., because it is inside
5169235e597SMatthias Springer // a loop), it may be conflicting with itself.
5179235e597SMatthias Springer if (canUseOpDominance && uConflictingWrite == uRead)
5187a1579acSMatthias Springer continue;
5197a1579acSMatthias Springer
5207a1579acSMatthias Springer // No conflict if the op interface says so.
5217a1579acSMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
5227a1579acSMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
5237a1579acSMatthias Springer continue;
5247a1579acSMatthias Springer
5257a1579acSMatthias Springer if (conflictingWritingOp != readingOp)
5267a1579acSMatthias Springer if (auto bufferizableOp =
5277a1579acSMatthias Springer options.dynCastBufferizableOp(conflictingWritingOp))
5287a1579acSMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
5297a1579acSMatthias Springer continue;
5307a1579acSMatthias Springer
5317a1579acSMatthias Springer // Ops are not conflicting if they are in mutually exclusive regions.
5329235e597SMatthias Springer //
5339235e597SMatthias Springer // Note: If ops are executed multiple times (e.g., because they are inside
5349235e597SMatthias Springer // a loop), mutually exclusive regions may be executed multiple
5359235e597SMatthias Springer // times.
5369235e597SMatthias Springer if (canUseOpDominance &&
5379235e597SMatthias Springer insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
5387a1579acSMatthias Springer continue;
5397a1579acSMatthias Springer
5407a1579acSMatthias Springer // Check all possible last writes.
5417a1579acSMatthias Springer for (Value lastWrite : lastWrites) {
5427a1579acSMatthias Springer // No conflict if the conflicting write happens before the last
5437a1579acSMatthias Springer // write.
5447a1579acSMatthias Springer if (Operation *writingOp = lastWrite.getDefiningOp()) {
5457a1579acSMatthias Springer if (happensBefore(conflictingWritingOp, writingOp, domInfo))
5467a1579acSMatthias Springer // conflictingWritingOp happens before writingOp. No conflict.
5477a1579acSMatthias Springer continue;
5487a1579acSMatthias Springer // No conflict if conflictingWritingOp is contained in writingOp.
5497a1579acSMatthias Springer if (writingOp->isProperAncestor(conflictingWritingOp))
5507a1579acSMatthias Springer continue;
5517a1579acSMatthias Springer } else {
5527a1579acSMatthias Springer auto bbArg = lastWrite.cast<BlockArgument>();
5537a1579acSMatthias Springer Block *block = bbArg.getOwner();
5547a1579acSMatthias Springer if (!block->findAncestorOpInBlock(*conflictingWritingOp))
5557a1579acSMatthias Springer // conflictingWritingOp happens outside of the block. No
5567a1579acSMatthias Springer // conflict.
5577a1579acSMatthias Springer continue;
5587a1579acSMatthias Springer }
5597a1579acSMatthias Springer
5607a1579acSMatthias Springer // No conflict if the conflicting write and the last write are the same
5617a1579acSMatthias Springer // use.
562585a8a32SMatthias Springer SmallVector<OpResult> aliasingOpResult =
563585a8a32SMatthias Springer state.getAliasingOpResult(*uConflictingWrite);
564585a8a32SMatthias Springer if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
5657a1579acSMatthias Springer continue;
5667a1579acSMatthias Springer
5677a1579acSMatthias Springer // All requirements are met. Conflict found!
5687a1579acSMatthias Springer
5697a1579acSMatthias Springer if (options.printConflicts)
5707a1579acSMatthias Springer annotateConflict(uRead, uConflictingWrite, lastWrite);
5717a1579acSMatthias Springer
5727a1579acSMatthias Springer return true;
5737a1579acSMatthias Springer }
5747a1579acSMatthias Springer }
5757a1579acSMatthias Springer }
5767a1579acSMatthias Springer
5777a1579acSMatthias Springer return false;
5787a1579acSMatthias Springer }
5797a1579acSMatthias Springer
580032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the writes.
getAliasingInplaceWrites(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)581032be233SMatthias Springer static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
582032be233SMatthias Springer const BufferizationAliasInfo &aliasInfo,
583032be233SMatthias Springer const AnalysisState &state) {
584032be233SMatthias Springer aliasInfo.applyOnAliases(root, [&](Value alias) {
585032be233SMatthias Springer for (auto &use : alias.getUses())
586032be233SMatthias Springer // Inplace write to a value that aliases root.
587032be233SMatthias Springer if (isInplaceMemoryWrite(use, aliasInfo, state))
588032be233SMatthias Springer res.insert(&use);
589032be233SMatthias Springer });
590032be233SMatthias Springer }
591032be233SMatthias Springer
592032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the reads.
getAliasingReads(DenseSet<OpOperand * > & res,Value root,const BufferizationAliasInfo & aliasInfo,const AnalysisState & state)593032be233SMatthias Springer static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
594032be233SMatthias Springer const BufferizationAliasInfo &aliasInfo,
595032be233SMatthias Springer const AnalysisState &state) {
596032be233SMatthias Springer aliasInfo.applyOnAliases(root, [&](Value alias) {
597032be233SMatthias Springer for (auto &use : alias.getUses())
598032be233SMatthias Springer // Read to a value that aliases root.
599032be233SMatthias Springer if (state.bufferizesToMemoryRead(use))
600032be233SMatthias Springer res.insert(&use);
601032be233SMatthias Springer });
602032be233SMatthias Springer }
603032be233SMatthias Springer
6047a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read
6057a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization
6067a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that
6077a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains.
6087a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace
6097a1579acSMatthias Springer /// bufferization decision.
6107a1579acSMatthias Springer ///
6117a1579acSMatthias Springer /// Example:
6127a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
6137a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
6147a1579acSMatthias Springer /// %e = tensor.extract_slice %1
6157a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
6167a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
6177a1579acSMatthias Springer ///
6187a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to
6197a1579acSMatthias Springer /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
6207a1579acSMatthias Springer /// conflict because:
6217a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1.
6227a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second
6237a1579acSMatthias Springer /// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
6247a1579acSMatthias Springer /// would no longer be reading the result of %1.
6257a1579acSMatthias Springer ///
6267a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a
6277a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would
6287a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions.
6297a1579acSMatthias Springer ///
6307a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null
6317a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions
6327a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked.
wouldCreateReadAfterWriteInterference(OpOperand & operand,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo,bool checkConsistencyOnly=false)6337a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference(
6349597b16aSMatthias Springer OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
6357a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo,
6367a1579acSMatthias Springer bool checkConsistencyOnly = false) {
6377a1579acSMatthias Springer // Collect reads and writes of all aliases of OpOperand and OpResult.
6387a1579acSMatthias Springer DenseSet<OpOperand *> usesRead, usesWrite;
639032be233SMatthias Springer getAliasingReads(usesRead, operand.get(), aliasInfo, state);
640032be233SMatthias Springer getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
641585a8a32SMatthias Springer for (OpResult result : state.getAliasingOpResult(operand)) {
642032be233SMatthias Springer getAliasingReads(usesRead, result, aliasInfo, state);
643032be233SMatthias Springer getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
6447a1579acSMatthias Springer }
6457a1579acSMatthias Springer if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
6467a1579acSMatthias Springer usesWrite.insert(&operand);
6477a1579acSMatthias Springer
6487a1579acSMatthias Springer return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
6497a1579acSMatthias Springer aliasInfo);
6507a1579acSMatthias Springer }
6517a1579acSMatthias Springer
652032be233SMatthias Springer /// Check the reverse SSA use-def chain (following aliasing OpOperands) for
653032be233SMatthias Springer /// non-writable tensor values. Stop searching when an out-of-place bufferized
654032be233SMatthias Springer /// OpOperand was found (or when the OpOperand was not bufferized yet).
655032be233SMatthias Springer /// `currentOpOperand` is assumed to be in-place, even if that decision was not
656032be233SMatthias Springer /// materialized in `aliasInfo` yet.
6577a1579acSMatthias Springer static bool
hasPrecedingAliasingNonWritableTensor(Value value,OpOperand * currentOpOperand,const BufferizationAliasInfo & aliasInfo,const OneShotAnalysisState & state)658032be233SMatthias Springer hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
6597a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo,
660032be233SMatthias Springer const OneShotAnalysisState &state) {
661032be233SMatthias Springer SmallVector<Value> worklist;
662032be233SMatthias Springer worklist.push_back(value);
663032be233SMatthias Springer while (!worklist.empty()) {
664032be233SMatthias Springer Value nextVal = worklist.pop_back_val();
665032be233SMatthias Springer if (!state.isWritable(nextVal))
666032be233SMatthias Springer return true;
667032be233SMatthias Springer
668032be233SMatthias Springer // If `nextVal` is not a BlockArgument: End of use-def chain reached.
669032be233SMatthias Springer auto opResult = nextVal.dyn_cast<OpResult>();
670032be233SMatthias Springer if (!opResult)
671032be233SMatthias Springer continue;
672032be233SMatthias Springer
673032be233SMatthias Springer // Follow reverse SSA use-def chain.
674032be233SMatthias Springer SmallVector<OpOperand *> aliasingOpOperands =
675032be233SMatthias Springer state.getAliasingOpOperand(opResult);
676032be233SMatthias Springer for (OpOperand *opOperand : aliasingOpOperands)
677032be233SMatthias Springer if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
678032be233SMatthias Springer worklist.push_back(opOperand->get());
679032be233SMatthias Springer }
6807a1579acSMatthias Springer return false;
681032be233SMatthias Springer }
6827a1579acSMatthias Springer
683032be233SMatthias Springer /// Return true if bufferizing `operand` inplace would create a write to a
684032be233SMatthias Springer /// non-writable buffer.
wouldCreateWriteToNonWritableBuffer(OpOperand & operand,const BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,bool checkConsistencyOnly=false)685032be233SMatthias Springer static bool wouldCreateWriteToNonWritableBuffer(
686032be233SMatthias Springer OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
687032be233SMatthias Springer OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
688032be233SMatthias Springer // Collect writes of all aliases of OpOperand and OpResult.
689032be233SMatthias Springer DenseSet<OpOperand *> usesWrite;
690032be233SMatthias Springer getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
691032be233SMatthias Springer for (OpResult result : state.getAliasingOpResult(operand)) {
692032be233SMatthias Springer getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
693032be233SMatthias Springer }
694032be233SMatthias Springer if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
695032be233SMatthias Springer usesWrite.insert(&operand);
6967a1579acSMatthias Springer
697032be233SMatthias Springer // Assuming that `operand` bufferizes in-place: For each write (to each
698032be233SMatthias Springer // alias), check if there is a non-writable tensor in the reverse SSA use-def
699032be233SMatthias Springer // chain.
700032be233SMatthias Springer for (OpOperand *uWrite : usesWrite)
701032be233SMatthias Springer if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
702032be233SMatthias Springer aliasInfo, state))
703032be233SMatthias Springer return true;
7047a1579acSMatthias Springer
705032be233SMatthias Springer return false;
7067a1579acSMatthias Springer }
7077a1579acSMatthias Springer
7087a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7097a1579acSMatthias Springer // Bufferization analyses.
7107a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7117a1579acSMatthias Springer
7127a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place.
bufferizableInPlaceAnalysisImpl(OpOperand & operand,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo)7137a1579acSMatthias Springer static LogicalResult bufferizableInPlaceAnalysisImpl(
714032be233SMatthias Springer OpOperand &operand, BufferizationAliasInfo &aliasInfo,
715032be233SMatthias Springer OneShotAnalysisState &state, const DominanceInfo &domInfo) {
7167a1579acSMatthias Springer bool foundInterference =
7177a1579acSMatthias Springer wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
7187a1579acSMatthias Springer wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
7197a1579acSMatthias Springer
7207a1579acSMatthias Springer if (foundInterference)
7217a1579acSMatthias Springer aliasInfo.bufferizeOutOfPlace(operand);
7227a1579acSMatthias Springer else
7237a1579acSMatthias Springer aliasInfo.bufferizeInPlace(operand, state);
7247a1579acSMatthias Springer
7257a1579acSMatthias Springer return success();
7267a1579acSMatthias Springer }
7277a1579acSMatthias Springer
7287a1579acSMatthias Springer /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
7297a1579acSMatthias Springer /// reverse and bufferize ops greedily. This is a good starter heuristic.
7307a1579acSMatthias Springer ///
7317a1579acSMatthias Springer /// Even if an op does not read or write, it may still create an alias when
7327a1579acSMatthias Springer /// bufferized in-place. An example of such ops is tensor.extract_slice.
7337a1579acSMatthias Springer ///
7347a1579acSMatthias Springer /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
7357a1579acSMatthias Springer ///
7367a1579acSMatthias Springer /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
7377a1579acSMatthias Springer /// cannot change the flow of information for either the source or the
7387a1579acSMatthias Springer /// result buffers.
7397a1579acSMatthias Springer ///
7407a1579acSMatthias Springer /// When bufferized inplace, an ExtractSliceOp does not by itself create any
7417a1579acSMatthias Springer /// read or write from memory. Instead, it has the effect of merging the alias
7427a1579acSMatthias Springer /// sets of the source and the result buffers.
7437a1579acSMatthias Springer ///
7447a1579acSMatthias Springer /// An analysis is required to ensure inplace bufferization would not result in
7457a1579acSMatthias Springer /// RaW dependence violations.
inPlaceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)7467a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
7477a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo,
748032be233SMatthias Springer OneShotAnalysisState &state,
7497a1579acSMatthias Springer const DominanceInfo &domInfo,
7507a1579acSMatthias Springer unsigned analysisFuzzerSeed = 0) {
7517a1579acSMatthias Springer if (analysisFuzzerSeed) {
7527a1579acSMatthias Springer // This is a fuzzer. For testing purposes only. Randomize the order in which
7537a1579acSMatthias Springer // operations are analyzed. The bufferization quality is likely worse, but
7547a1579acSMatthias Springer // we want to make sure that no assertions are triggered anywhere.
7557a1579acSMatthias Springer std::mt19937 g(analysisFuzzerSeed);
7567a1579acSMatthias Springer llvm::shuffle(ops.begin(), ops.end(), g);
7577a1579acSMatthias Springer }
7587a1579acSMatthias Springer
7597a1579acSMatthias Springer // Walk ops in reverse for better interference analysis.
7607a1579acSMatthias Springer for (Operation *op : reverse(ops))
7617a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands())
7627a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>())
7637a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
7647a1579acSMatthias Springer if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
7657a1579acSMatthias Springer state, domInfo)))
7667a1579acSMatthias Springer return failure();
7677a1579acSMatthias Springer
7687a1579acSMatthias Springer return success();
7697a1579acSMatthias Springer }
7707a1579acSMatthias Springer
7717a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand.
hasTensorSemantics(Operation * op)7727a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) {
7737a1579acSMatthias Springer bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
7747a1579acSMatthias Springer bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
7757a1579acSMatthias Springer return hasTensorResult || hasTensorOperand;
7767a1579acSMatthias Springer }
7777a1579acSMatthias Springer
7787a1579acSMatthias Springer /// Analyze all ops that are contained in `op`.
inPlaceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,OneShotAnalysisState & state,const DominanceInfo & domInfo,unsigned analysisFuzzerSeed=0)7797a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(Operation *op,
7807a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo,
781032be233SMatthias Springer OneShotAnalysisState &state,
7827a1579acSMatthias Springer const DominanceInfo &domInfo,
7837a1579acSMatthias Springer unsigned analysisFuzzerSeed = 0) {
7847a1579acSMatthias Springer // Collect ops so we can build our own reverse traversal.
7857a1579acSMatthias Springer SmallVector<Operation *> ops;
7867a1579acSMatthias Springer op->walk([&](Operation *op) {
7877a1579acSMatthias Springer // No tensors => no buffers.
7887a1579acSMatthias Springer if (!hasTensorSemantics(op))
7897a1579acSMatthias Springer return;
7907a1579acSMatthias Springer ops.push_back(op);
7917a1579acSMatthias Springer });
7927a1579acSMatthias Springer
7937a1579acSMatthias Springer return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
7947a1579acSMatthias Springer }
7957a1579acSMatthias Springer
7967a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
equivalenceAnalysis(SmallVector<Operation * > & ops,BufferizationAliasInfo & aliasInfo,AnalysisState & state)7977a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops,
7987a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo,
7999597b16aSMatthias Springer AnalysisState &state) {
8007a1579acSMatthias Springer for (Operation *op : ops)
8017a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
8027a1579acSMatthias Springer for (OpResult opResult : op->getOpResults())
8037a1579acSMatthias Springer if (opResult.getType().isa<TensorType>())
8047a1579acSMatthias Springer for (OpOperand *opOperand :
8057a1579acSMatthias Springer bufferizableOp.getAliasingOpOperand(opResult, state))
8067a1579acSMatthias Springer if (state.isInPlace(*opOperand))
8077a1579acSMatthias Springer if (bufferizableOp.bufferRelation(opResult, state) ==
8087a1579acSMatthias Springer BufferRelation::Equivalent)
8097a1579acSMatthias Springer aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
8107a1579acSMatthias Springer }
8117a1579acSMatthias Springer
8127a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
8137a1579acSMatthias Springer /// in `op`.
equivalenceAnalysis(Operation * op,BufferizationAliasInfo & aliasInfo,AnalysisState & state)8147a1579acSMatthias Springer static void equivalenceAnalysis(Operation *op,
8157a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo,
8169597b16aSMatthias Springer AnalysisState &state) {
8177a1579acSMatthias Springer // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
8187a1579acSMatthias Springer SmallVector<Operation *> ops;
8197a1579acSMatthias Springer op->walk<WalkOrder::PostOrder>([&](Operation *op) {
8207a1579acSMatthias Springer // No tensors => no buffers.
8217a1579acSMatthias Springer if (none_of(op->getResultTypes(), isaTensor))
8227a1579acSMatthias Springer return;
8237a1579acSMatthias Springer ops.push_back(op);
8247a1579acSMatthias Springer });
8257a1579acSMatthias Springer
8267a1579acSMatthias Springer equivalenceAnalysis(ops, aliasInfo, state);
8277a1579acSMatthias Springer }
8287a1579acSMatthias Springer
8297a1579acSMatthias Springer /// Assert that the current bufferization decisions are consistent.
8307a1579acSMatthias Springer static LogicalResult
checkAliasInfoConsistency(Operation * op,const DominanceInfo & domInfo,AnalysisState & state,const BufferizationAliasInfo & aliasInfo)8317a1579acSMatthias Springer checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
8329597b16aSMatthias Springer AnalysisState &state,
8337a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo) {
8347a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions();
8357a1579acSMatthias Springer Operation *inconsistentOp = nullptr;
8367a1579acSMatthias Springer WalkResult walkResult = op->walk([&](Operation *op) {
8377a1579acSMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(op))
8387a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands())
8397a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) {
8407a1579acSMatthias Springer if (wouldCreateReadAfterWriteInterference(
8417a1579acSMatthias Springer opOperand, domInfo, state, aliasInfo,
8427a1579acSMatthias Springer /*checkConsistencyOnly=*/true)) {
8437a1579acSMatthias Springer // This error can happen if certain "mustBufferizeInPlace" interface
8447a1579acSMatthias Springer // methods are implemented incorrectly, such that the IR already has
8457a1579acSMatthias Springer // a RaW conflict before making any bufferization decisions.
8467a1579acSMatthias Springer inconsistentOp = op;
8477a1579acSMatthias Springer return WalkResult::interrupt();
8487a1579acSMatthias Springer }
8497a1579acSMatthias Springer }
8507a1579acSMatthias Springer return WalkResult::advance();
8517a1579acSMatthias Springer });
8527a1579acSMatthias Springer
8537a1579acSMatthias Springer if (walkResult.wasInterrupted())
8547a1579acSMatthias Springer return inconsistentOp->emitError("input IR has RaW conflict");
8557a1579acSMatthias Springer return success();
8567a1579acSMatthias Springer }
8577a1579acSMatthias Springer
8587a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only.
8597a1579acSMatthias Springer static void
annotateOpsWithBufferizationMarkers(Operation * op,const BufferizationAliasInfo & aliasInfo,AnalysisState & state)8607a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op,
8617a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo,
8629597b16aSMatthias Springer AnalysisState &state) {
8637a1579acSMatthias Springer op->walk([&](Operation *op) {
8647a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
8657a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands())
8667a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>())
8677a1579acSMatthias Springer setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
8687a1579acSMatthias Springer });
8697a1579acSMatthias Springer }
8707a1579acSMatthias Springer
8717a1579acSMatthias Springer /// Assert that IR is in destination-passing style. I.e., every value that is
8727a1579acSMatthias Springer /// returned or yielded from a block is:
8737a1579acSMatthias Springer /// * aliasing a bbArg of that block or a parent block, or
8747a1579acSMatthias Springer /// * aliasing an OpResult of a op in a parent block.
8757a1579acSMatthias Springer ///
8767a1579acSMatthias Springer /// Example:
8777a1579acSMatthias Springer /// ```
8787a1579acSMatthias Springer /// %0 = "some_op" : tensor<?xf32>
8797a1579acSMatthias Springer /// %1 = scf.if %c -> (tensor<?xf32>) {
8807a1579acSMatthias Springer /// scf.yield %0 : tensor<?xf32>
8817a1579acSMatthias Springer /// } else {
882ffdbecccSMatthias Springer /// %t = linalg.alloc_tensor : tensor<?xf32>
8837a1579acSMatthias Springer /// scf.yield %t : tensor<?xf32>
8847a1579acSMatthias Springer /// }
8857a1579acSMatthias Springer /// ```
8867a1579acSMatthias Springer /// In the above example, the first scf.yield op satifies destination-passing
8877a1579acSMatthias Springer /// style because the yielded value %0 is defined in the parent block. The
8887a1579acSMatthias Springer /// second scf.yield op does not satisfy destination-passing style because the
8897a1579acSMatthias Springer /// yielded value %t is defined in the same block as the scf.yield op.
8907a1579acSMatthias Springer // TODO: The current implementation checks for equivalent values instead of
8917a1579acSMatthias Springer // aliasing values, which is stricter than needed. We can currently not check
8927a1579acSMatthias Springer // for aliasing values because the analysis is a maybe-alias analysis and we
8937a1579acSMatthias Springer // need a must-alias analysis here.
894cdb7675cSMatthias Springer static LogicalResult
assertDestinationPassingStyle(Operation * op,AnalysisState & state,BufferizationAliasInfo & aliasInfo,SmallVector<Operation * > & newOps)8959597b16aSMatthias Springer assertDestinationPassingStyle(Operation *op, AnalysisState &state,
8967a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo,
897cdb7675cSMatthias Springer SmallVector<Operation *> &newOps) {
8987a1579acSMatthias Springer LogicalResult status = success();
8997a1579acSMatthias Springer DominanceInfo domInfo(op);
9007a1579acSMatthias Springer op->walk([&](Operation *returnOp) {
9013b426868SMatthias Springer if (!isRegionReturnLike(returnOp) ||
9023b426868SMatthias Springer !state.getOptions().isOpAllowed(returnOp))
9037a1579acSMatthias Springer return WalkResult::advance();
9047a1579acSMatthias Springer
9057a1579acSMatthias Springer for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
9067a1579acSMatthias Springer Value returnVal = returnValOperand.get();
9077a1579acSMatthias Springer // Skip non-tensor values.
9087a1579acSMatthias Springer if (!returnVal.getType().isa<TensorType>())
9097a1579acSMatthias Springer continue;
9107a1579acSMatthias Springer
9117a1579acSMatthias Springer bool foundEquivValue = false;
9127a1579acSMatthias Springer aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
9137a1579acSMatthias Springer if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
9147a1579acSMatthias Springer Operation *definingOp = bbArg.getOwner()->getParentOp();
9157a1579acSMatthias Springer if (definingOp->isProperAncestor(returnOp))
9167a1579acSMatthias Springer foundEquivValue = true;
9177a1579acSMatthias Springer return;
9187a1579acSMatthias Springer }
9197a1579acSMatthias Springer
9207a1579acSMatthias Springer Operation *definingOp = equivVal.getDefiningOp();
9217a1579acSMatthias Springer if (definingOp->getBlock()->findAncestorOpInBlock(
9227a1579acSMatthias Springer *returnOp->getParentOp()))
9237a1579acSMatthias Springer // Skip ops that happen after `returnOp` and parent ops.
9247a1579acSMatthias Springer if (happensBefore(definingOp, returnOp, domInfo))
9257a1579acSMatthias Springer foundEquivValue = true;
9267a1579acSMatthias Springer });
9277a1579acSMatthias Springer
9287a1579acSMatthias Springer if (!foundEquivValue)
9297a1579acSMatthias Springer status =
9307a1579acSMatthias Springer returnOp->emitError()
9317a1579acSMatthias Springer << "operand #" << returnValOperand.getOperandNumber()
9327a1579acSMatthias Springer << " of ReturnLike op does not satisfy destination passing style";
9337a1579acSMatthias Springer }
9347a1579acSMatthias Springer
9357a1579acSMatthias Springer return WalkResult::advance();
9367a1579acSMatthias Springer });
9377a1579acSMatthias Springer
9387a1579acSMatthias Springer return status;
9397a1579acSMatthias Springer }
9407a1579acSMatthias Springer
analyzeOp(Operation * op,OneShotAnalysisState & state)9417a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op,
9429597b16aSMatthias Springer OneShotAnalysisState &state) {
9437a1579acSMatthias Springer DominanceInfo domInfo(op);
9447a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
9457a1579acSMatthias Springer const auto &options =
9469597b16aSMatthias Springer static_cast<const OneShotBufferizationOptions &>(state.getOptions());
9477a1579acSMatthias Springer
948d6dab38aSMatthias Springer // Catch incorrect API usage.
949d6dab38aSMatthias Springer assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
950d6dab38aSMatthias Springer !options.bufferizeFunctionBoundaries) &&
951d6dab38aSMatthias Springer "must use ModuleBufferize to bufferize function boundaries");
952d6dab38aSMatthias Springer
9537a1579acSMatthias Springer if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
9547a1579acSMatthias Springer return failure();
9557a1579acSMatthias Springer
9567a1579acSMatthias Springer // If the analysis fails, just return.
9577a1579acSMatthias Springer if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
9587a1579acSMatthias Springer options.analysisFuzzerSeed)))
9597a1579acSMatthias Springer return failure();
9607a1579acSMatthias Springer equivalenceAnalysis(op, aliasInfo, state);
9617a1579acSMatthias Springer
962d1d79920SMatthias Springer bool failedAnalysis = false;
963855a11eeSMatthias Springer if (!options.allowReturnAllocs) {
9647a1579acSMatthias Springer SmallVector<Operation *> newOps;
965d1d79920SMatthias Springer failedAnalysis |=
966d1d79920SMatthias Springer failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
9677a1579acSMatthias Springer }
9687a1579acSMatthias Springer
969988748c0SMatthias Springer // Gather some extra analysis data.
9709e24f0f4SMatthias Springer state.gatherYieldedTensors(op);
971988748c0SMatthias Springer state.gatherUndefinedTensorUses(op);
9729e24f0f4SMatthias Springer
9734ec00fb3SMatthias Springer // Analysis verification: After setting up alias/equivalence sets, each op
9744ec00fb3SMatthias Springer // can check for expected invariants/limitations and fail the analysis if
9754ec00fb3SMatthias Springer // necessary.
9764ec00fb3SMatthias Springer op->walk([&](Operation *op) {
9774ec00fb3SMatthias Springer if (BufferizableOpInterface bufferizableOp =
9784ec00fb3SMatthias Springer options.dynCastBufferizableOp(op))
979d1d79920SMatthias Springer failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
9804ec00fb3SMatthias Springer });
9814ec00fb3SMatthias Springer
9827a1579acSMatthias Springer // Annotate operations if we only want to report the analysis.
9837a1579acSMatthias Springer if (options.testAnalysisOnly)
9847a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
9857a1579acSMatthias Springer
986d1d79920SMatthias Springer return success(!failedAnalysis);
9877a1579acSMatthias Springer }
9887a1579acSMatthias Springer
9899597b16aSMatthias Springer LogicalResult
runOneShotBufferize(Operation * op,const OneShotBufferizationOptions & options)9909597b16aSMatthias Springer bufferization::runOneShotBufferize(Operation *op,
9919597b16aSMatthias Springer const OneShotBufferizationOptions &options) {
9929597b16aSMatthias Springer OneShotAnalysisState state(op, options);
993*b3ebe3beSMatthias Springer if (failed(insertTensorCopies(op, options)))
9947a1579acSMatthias Springer return failure();
995d2dacde5SMatthias Springer if (options.testAnalysisOnly)
9967a1579acSMatthias Springer return success();
997*b3ebe3beSMatthias Springer return bufferizeOp(op, options, /*copyBeforeWrite=*/false);
9987a1579acSMatthias Springer }
999