17a1579acSMatthias Springer //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer //
97a1579acSMatthias Springer // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
107a1579acSMatthias Springer // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
117a1579acSMatthias Springer // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
127a1579acSMatthias Springer // call graphs.
137a1579acSMatthias Springer //
147a1579acSMatthias Springer // One-Shot Bufferize consists of two phases.
157a1579acSMatthias Springer //
167a1579acSMatthias Springer // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
177a1579acSMatthias Springer //    inserting buffer copies. The analysis queries op bufferization semantics
187a1579acSMatthias Springer //    via `BufferizableOpInterface`.
197a1579acSMatthias Springer // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
207a1579acSMatthias Springer //    function does not generate buffer copies for OpResults that were decided
217a1579acSMatthias Springer //    to bufferize inplace during the analysis phase.
227a1579acSMatthias Springer //
237a1579acSMatthias Springer // This file contains only the analysis. The actual bufferization is implemented
247a1579acSMatthias Springer // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
257a1579acSMatthias Springer // helper function `runOneShotBufferize` that analyzes an op (and its nested
267a1579acSMatthias Springer // ops) and then bufferizes it.
277a1579acSMatthias Springer //
287a1579acSMatthias Springer // Inplace bufferization decisions are passed from the analysis to the
297a1579acSMatthias Springer // bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
307a1579acSMatthias Springer // They can be printed for debugging purposes with `testAnalysisOnly`.
317a1579acSMatthias Springer //
327a1579acSMatthias Springer // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
337a1579acSMatthias Springer // treated conservatively. E.g., the analysis has to assume that their tensor
347a1579acSMatthias Springer // OpOperands bufferize to memory writes. While such ops can be analyzed, they
357a1579acSMatthias Springer // are not bufferized and remain in the IR. to_tensor and to_memref ops are
367a1579acSMatthias Springer // inserted at the bufferization boundary.
377a1579acSMatthias Springer //
387a1579acSMatthias Springer // This analysis caters to high-performance codegen where buffer reuse is deemed
397a1579acSMatthias Springer // critical: the analysis should fail if the bufferized form of the function
407a1579acSMatthias Springer // needs to return a buffer, unless `allowReturnMemref` is enabled.
417a1579acSMatthias Springer 
427a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
437a1579acSMatthias Springer 
447a1579acSMatthias Springer #include <random>
457a1579acSMatthias Springer 
467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
497a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
507a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
517a1579acSMatthias Springer #include "mlir/IR/Dominance.h"
527a1579acSMatthias Springer #include "mlir/IR/Operation.h"
537a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
547a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
557a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h"
567a1579acSMatthias Springer #include "llvm/ADT/SetVector.h"
577a1579acSMatthias Springer 
587a1579acSMatthias Springer using namespace mlir;
597a1579acSMatthias Springer using namespace mlir::bufferization;
607a1579acSMatthias Springer 
617a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); }
627a1579acSMatthias Springer 
637a1579acSMatthias Springer //===----------------------------------------------------------------------===//
647a1579acSMatthias Springer // Bufferization-specific attribute manipulation.
657a1579acSMatthias Springer // These are for testing and debugging only. Bufferization information is
667a1579acSMatthias Springer // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
677a1579acSMatthias Springer // is annotated with the results of the analysis (copied from
687a1579acSMatthias Springer // BufferizationAliasInfo), so that they can be checked in tests.
697a1579acSMatthias Springer //===----------------------------------------------------------------------===//
707a1579acSMatthias Springer 
717a1579acSMatthias Springer /// Attribute marker to specify op results that can be bufferized inPlace.
727a1579acSMatthias Springer constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";
737a1579acSMatthias Springer 
747a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace.
757a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
767a1579acSMatthias Springer   Operation *op = opOperand.getOwner();
777a1579acSMatthias Springer   auto attr =
787a1579acSMatthias Springer       op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
797a1579acSMatthias Springer   SmallVector<StringRef> inPlaceVector;
807a1579acSMatthias Springer   if (attr) {
817a1579acSMatthias Springer     inPlaceVector = SmallVector<StringRef>(
827a1579acSMatthias Springer         llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
837a1579acSMatthias Springer   } else {
847a1579acSMatthias Springer     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
857a1579acSMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
867a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
877a1579acSMatthias Springer         inPlaceVector[opOperand.getOperandNumber()] = "false";
887a1579acSMatthias Springer   }
897a1579acSMatthias Springer 
907a1579acSMatthias Springer   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
917a1579acSMatthias Springer   op->setAttr(kInPlaceResultsAttrName,
927a1579acSMatthias Springer               OpBuilder(op).getStrArrayAttr(inPlaceVector));
937a1579acSMatthias Springer }
947a1579acSMatthias Springer 
957a1579acSMatthias Springer //===----------------------------------------------------------------------===//
967a1579acSMatthias Springer // BufferizationAliasInfo
977a1579acSMatthias Springer //===----------------------------------------------------------------------===//
987a1579acSMatthias Springer 
997a1579acSMatthias Springer BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
1007a1579acSMatthias Springer   rootOp->walk([&](Operation *op) {
1017a1579acSMatthias Springer     for (Value v : op->getResults())
1027a1579acSMatthias Springer       if (v.getType().isa<TensorType>())
1037a1579acSMatthias Springer         createAliasInfoEntry(v);
1047a1579acSMatthias Springer     for (Region &r : op->getRegions())
1057a1579acSMatthias Springer       for (Block &b : r.getBlocks())
1067a1579acSMatthias Springer         for (auto bbArg : b.getArguments())
1077a1579acSMatthias Springer           if (bbArg.getType().isa<TensorType>())
1087a1579acSMatthias Springer             createAliasInfoEntry(bbArg);
1097a1579acSMatthias Springer   });
1107a1579acSMatthias Springer }
1117a1579acSMatthias Springer 
1127a1579acSMatthias Springer /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
1137a1579acSMatthias Springer /// beginning the alias and equivalence sets only contain `v` itself.
1147a1579acSMatthias Springer void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
1157a1579acSMatthias Springer   aliasInfo.insert(v);
1167a1579acSMatthias Springer   equivalentInfo.insert(v);
1177a1579acSMatthias Springer }
1187a1579acSMatthias Springer 
1197a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1207a1579acSMatthias Springer /// `alias`.
1217a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
1227a1579acSMatthias Springer   createAliasInfoEntry(newValue);
1237a1579acSMatthias Springer   aliasInfo.unionSets(newValue, alias);
1247a1579acSMatthias Springer }
1257a1579acSMatthias Springer 
1267a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of
1277a1579acSMatthias Springer /// `alias`. Additionally, merge their equivalence classes.
1287a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
1297a1579acSMatthias Springer                                                         Value alias) {
1307a1579acSMatthias Springer   insertNewBufferAlias(newValue, alias);
1317a1579acSMatthias Springer   equivalentInfo.unionSets(newValue, alias);
1327a1579acSMatthias Springer }
1337a1579acSMatthias Springer 
1347a1579acSMatthias Springer /// Return `true` if a value was marked as in-place bufferized.
1357a1579acSMatthias Springer bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
1367a1579acSMatthias Springer   return inplaceBufferized.contains(&operand);
1377a1579acSMatthias Springer }
1387a1579acSMatthias Springer 
1397a1579acSMatthias Springer /// Set the inPlace bufferization spec to true.
1407a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
1417a1579acSMatthias Springer                                               BufferizationState &state) {
1427a1579acSMatthias Springer   markInPlace(operand);
143585a8a32SMatthias Springer   for (OpResult result : state.getAliasingOpResult(operand))
1447a1579acSMatthias Springer     aliasInfo.unionSets(result, operand.get());
1457a1579acSMatthias Springer }
1467a1579acSMatthias Springer 
1477a1579acSMatthias Springer /// Set the inPlace bufferization spec to false.
1487a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
1497a1579acSMatthias Springer   assert(!inplaceBufferized.contains(&operand) &&
1507a1579acSMatthias Springer          "OpOperand was already decided to bufferize inplace");
1517a1579acSMatthias Springer }
1527a1579acSMatthias Springer 
1537a1579acSMatthias Springer /// Apply `fun` to all the members of the equivalence class of `v`.
1547a1579acSMatthias Springer void BufferizationAliasInfo::applyOnEquivalenceClass(
1557a1579acSMatthias Springer     Value v, function_ref<void(Value)> fun) const {
1567a1579acSMatthias Springer   auto leaderIt = equivalentInfo.findLeader(v);
1577a1579acSMatthias Springer   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
1587a1579acSMatthias Springer        ++mit) {
1597a1579acSMatthias Springer     fun(*mit);
1607a1579acSMatthias Springer   }
1617a1579acSMatthias Springer }
1627a1579acSMatthias Springer 
1637a1579acSMatthias Springer /// Apply `fun` to all aliases of `v`.
1647a1579acSMatthias Springer void BufferizationAliasInfo::applyOnAliases(
1657a1579acSMatthias Springer     Value v, function_ref<void(Value)> fun) const {
1667a1579acSMatthias Springer   auto leaderIt = aliasInfo.findLeader(v);
1677a1579acSMatthias Springer   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
1687a1579acSMatthias Springer     fun(*mit);
1697a1579acSMatthias Springer   }
1707a1579acSMatthias Springer }
1717a1579acSMatthias Springer 
1727a1579acSMatthias Springer BufferizationAliasInfo::EquivalenceClassRangeType
1737a1579acSMatthias Springer BufferizationAliasInfo::getAliases(Value v) const {
1747a1579acSMatthias Springer   DenseSet<Value> res;
1757a1579acSMatthias Springer   auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
1767a1579acSMatthias Springer   for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
1777a1579acSMatthias Springer        mit != meit; ++mit) {
1787a1579acSMatthias Springer     res.insert(static_cast<Value>(*mit));
1797a1579acSMatthias Springer   }
1807a1579acSMatthias Springer   return BufferizationAliasInfo::EquivalenceClassRangeType(
1817a1579acSMatthias Springer       aliasInfo.member_begin(it), aliasInfo.member_end());
1827a1579acSMatthias Springer }
1837a1579acSMatthias Springer 
1847a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1857a1579acSMatthias Springer // AnalysisBufferizationState
1867a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1877a1579acSMatthias Springer 
1887a1579acSMatthias Springer AnalysisBufferizationState::AnalysisBufferizationState(
1897a1579acSMatthias Springer     Operation *op, const AnalysisBufferizationOptions &options)
1907a1579acSMatthias Springer     : BufferizationState(options), aliasInfo(op) {
1917a1579acSMatthias Springer   // Set up alias sets for OpResults that must bufferize in-place. This should
1927a1579acSMatthias Springer   // be done before making any other bufferization decisions.
1937a1579acSMatthias Springer   op->walk([&](BufferizableOpInterface bufferizableOp) {
1947a1579acSMatthias Springer     if (!options.isOpAllowed(bufferizableOp))
1957a1579acSMatthias Springer       return WalkResult::skip();
1967a1579acSMatthias Springer     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
1977a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
1987a1579acSMatthias Springer         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
199585a8a32SMatthias Springer           for (OpResult opResult :
2007a1579acSMatthias Springer                bufferizableOp.getAliasingOpResult(opOperand, *this))
2017a1579acSMatthias Springer             aliasInfo.unionAliasSets(opOperand.get(), opResult);
2027a1579acSMatthias Springer           aliasInfo.markInPlace(opOperand);
2037a1579acSMatthias Springer         }
2047a1579acSMatthias Springer     }
2057a1579acSMatthias Springer     return WalkResult::advance();
2067a1579acSMatthias Springer   });
2077a1579acSMatthias Springer }
2087a1579acSMatthias Springer 
2097a1579acSMatthias Springer bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const {
2107a1579acSMatthias Springer   return aliasInfo.isInPlace(opOperand);
2117a1579acSMatthias Springer }
2127a1579acSMatthias Springer 
2137a1579acSMatthias Springer bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
2147a1579acSMatthias Springer                                                                Value v2) const {
2157a1579acSMatthias Springer   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
2167a1579acSMatthias Springer }
2177a1579acSMatthias Springer 
2187a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2197a1579acSMatthias Springer // Bufferization-specific alias analysis.
2207a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2217a1579acSMatthias Springer 
2227a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place.
2237a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand,
2247a1579acSMatthias Springer                                  const BufferizationAliasInfo &aliasInfo,
2257a1579acSMatthias Springer                                  BufferizationState &state) {
2267a1579acSMatthias Springer   // OpOperands that do not bufferize to a memory write do not write in-place.
2277a1579acSMatthias Springer   if (!state.bufferizesToMemoryWrite(opOperand))
2287a1579acSMatthias Springer     return false;
2297a1579acSMatthias Springer   // Check current bufferization decisions.
2307a1579acSMatthias Springer   return aliasInfo.isInPlace(opOperand);
2317a1579acSMatthias Springer }
2327a1579acSMatthias Springer 
2337a1579acSMatthias Springer /// Return true if, under current bufferization decisions, the buffer of `value`
2347a1579acSMatthias Springer /// is not writable.
2357a1579acSMatthias Springer static bool aliasesNonWritableBuffer(Value value,
2367a1579acSMatthias Springer                                      const BufferizationAliasInfo &aliasInfo,
2377a1579acSMatthias Springer                                      BufferizationState &state) {
2387a1579acSMatthias Springer   bool foundNonWritableBuffer = false;
2397a1579acSMatthias Springer   aliasInfo.applyOnAliases(value, [&](Value v) {
2407a1579acSMatthias Springer     // Query BufferizableOpInterface to see if the value is writable.
2417a1579acSMatthias Springer     // TODO: Out-of-place bufferized value could be considered writable.
2427a1579acSMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
2437a1579acSMatthias Springer       if (bufferizableOp && bufferizableOp.isWritable(v, state))
2447a1579acSMatthias Springer         return;
2457a1579acSMatthias Springer 
2467a1579acSMatthias Springer     // Query BufferizableOpInterface to see if the BlockArgument is writable.
2477a1579acSMatthias Springer     if (auto bbArg = v.dyn_cast<BlockArgument>())
2487a1579acSMatthias Springer       if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
2497a1579acSMatthias Springer               bbArg.getOwner()->getParentOp()))
2507a1579acSMatthias Springer         if (bufferizableOp.isWritable(bbArg, state))
2517a1579acSMatthias Springer           return;
2527a1579acSMatthias Springer 
2537a1579acSMatthias Springer     foundNonWritableBuffer = true;
2547a1579acSMatthias Springer   });
2557a1579acSMatthias Springer 
2567a1579acSMatthias Springer   return foundNonWritableBuffer;
2577a1579acSMatthias Springer }
2587a1579acSMatthias Springer 
2597a1579acSMatthias Springer /// Return true if the buffer to which `operand` would bufferize is equivalent
2607a1579acSMatthias Springer /// to some buffer write.
2617a1579acSMatthias Springer static bool aliasesInPlaceWrite(Value value,
2627a1579acSMatthias Springer                                 const BufferizationAliasInfo &aliasInfo,
2637a1579acSMatthias Springer                                 BufferizationState &state) {
2647a1579acSMatthias Springer   bool foundInplaceWrite = false;
2657a1579acSMatthias Springer   aliasInfo.applyOnAliases(value, [&](Value v) {
2667a1579acSMatthias Springer     for (auto &use : v.getUses()) {
2677a1579acSMatthias Springer       if (isInplaceMemoryWrite(use, aliasInfo, state)) {
2687a1579acSMatthias Springer         foundInplaceWrite = true;
2697a1579acSMatthias Springer         return;
2707a1579acSMatthias Springer       }
2717a1579acSMatthias Springer     }
2727a1579acSMatthias Springer   });
2737a1579acSMatthias Springer   return foundInplaceWrite;
2747a1579acSMatthias Springer }
2757a1579acSMatthias Springer 
2767a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
2777a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`.
2787a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b,
2797a1579acSMatthias Springer                           const DominanceInfo &domInfo) {
2807a1579acSMatthias Springer   do {
2817a1579acSMatthias Springer     // TODO: Instead of isProperAncestor + properlyDominates, we should use
2827a1579acSMatthias Springer     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
2837a1579acSMatthias Springer     if (a->isProperAncestor(b))
2847a1579acSMatthias Springer       return false;
2857a1579acSMatthias Springer     if (domInfo.properlyDominates(a, b))
2867a1579acSMatthias Springer       return true;
2877a1579acSMatthias Springer   } while ((a = a->getParentOp()));
2887a1579acSMatthias Springer   return false;
2897a1579acSMatthias Springer }
2907a1579acSMatthias Springer 
2917a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict.
2927a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
2937a1579acSMatthias Springer                              Value lastWrite) {
2947a1579acSMatthias Springer   static uint64_t counter = 0;
2957a1579acSMatthias Springer   Operation *readingOp = uRead->getOwner();
2967a1579acSMatthias Springer   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
2977a1579acSMatthias Springer 
2987a1579acSMatthias Springer   OpBuilder b(conflictingWritingOp->getContext());
2997a1579acSMatthias Springer   std::string id = "C_" + std::to_string(counter++);
3007a1579acSMatthias Springer 
3017a1579acSMatthias Springer   std::string conflictingWriteAttr =
3027a1579acSMatthias Springer       id +
3037a1579acSMatthias Springer       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
3047a1579acSMatthias Springer       "]";
3057a1579acSMatthias Springer   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
3067a1579acSMatthias Springer 
3077a1579acSMatthias Springer   std::string readAttr =
3087a1579acSMatthias Springer       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
3097a1579acSMatthias Springer   readingOp->setAttr(readAttr, b.getUnitAttr());
3107a1579acSMatthias Springer 
3117a1579acSMatthias Springer   if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
3127a1579acSMatthias Springer     std::string lastWriteAttr = id + "[LAST-WRITE: result " +
3137a1579acSMatthias Springer                                 std::to_string(opResult.getResultNumber()) +
3147a1579acSMatthias Springer                                 "]";
3157a1579acSMatthias Springer     opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
3167a1579acSMatthias Springer   } else {
3177a1579acSMatthias Springer     auto bbArg = lastWrite.cast<BlockArgument>();
3187a1579acSMatthias Springer     std::string lastWriteAttr =
3197a1579acSMatthias Springer         id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
3207a1579acSMatthias Springer     bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
3217a1579acSMatthias Springer   }
3227a1579acSMatthias Springer }
3237a1579acSMatthias Springer 
3247a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under
3257a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that
3267a1579acSMatthias Springer /// all given writes bufferize inplace.
3277a1579acSMatthias Springer ///
3287a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read
3297a1579acSMatthias Springer /// the result of a write W1. But because of bufferization decisions, R actually
3307a1579acSMatthias Springer /// reads another write W2.
3317a1579acSMatthias Springer static bool hasReadAfterWriteInterference(
3327a1579acSMatthias Springer     const DenseSet<OpOperand *> &usesRead,
3337a1579acSMatthias Springer     const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
3347a1579acSMatthias Springer     BufferizationState &state, const BufferizationAliasInfo &aliasInfo) {
3357a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
3367a1579acSMatthias Springer 
3377a1579acSMatthias Springer   for (OpOperand *uRead : usesRead) {
3387a1579acSMatthias Springer     Operation *readingOp = uRead->getOwner();
3397a1579acSMatthias Springer 
3407a1579acSMatthias Springer     // Find most recent writes of uRead by following the SSA use-def chain.
3417a1579acSMatthias Springer     // E.g.:
3427a1579acSMatthias Springer     //
3437a1579acSMatthias Springer     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
3447a1579acSMatthias Springer     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
3457a1579acSMatthias Springer     // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
3467a1579acSMatthias Springer     //
3477a1579acSMatthias Springer     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
3487a1579acSMatthias Springer     // is %0. Note that operations that create an alias but do not write (such
3497a1579acSMatthias Springer     // as ExtractSliceOp) are skipped.
3507a1579acSMatthias Springer     SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
3517a1579acSMatthias Springer 
3527a1579acSMatthias Springer     // Look for conflicting memory writes. Potential conflicts are writes to an
3537a1579acSMatthias Springer     // alias that have been decided to bufferize inplace.
3547a1579acSMatthias Springer     for (OpOperand *uConflictingWrite : usesWrite) {
3557a1579acSMatthias Springer       // Throughout this loop, check for multiple requirements that have to be
3567a1579acSMatthias Springer       // met for uConflictingWrite to be an actual conflict.
3577a1579acSMatthias Springer       Operation *conflictingWritingOp = uConflictingWrite->getOwner();
3587a1579acSMatthias Springer 
3597a1579acSMatthias Springer       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
3607a1579acSMatthias Springer       // write is not visible when reading.
3617a1579acSMatthias Springer       if (happensBefore(readingOp, conflictingWritingOp, domInfo))
3627a1579acSMatthias Springer         continue;
3637a1579acSMatthias Springer 
3647a1579acSMatthias Springer       // No conflict if the reading use equals the use of the conflicting write.
3657a1579acSMatthias Springer       // A use cannot conflict with itself. Note: Just being the same op is not
3667a1579acSMatthias Springer       // enough. It has to be the same use.
3677a1579acSMatthias Springer       if (uConflictingWrite == uRead)
3687a1579acSMatthias Springer         continue;
3697a1579acSMatthias Springer 
3707a1579acSMatthias Springer       // No conflict if the op interface says so.
3717a1579acSMatthias Springer       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
3727a1579acSMatthias Springer         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
3737a1579acSMatthias Springer           continue;
3747a1579acSMatthias Springer 
3757a1579acSMatthias Springer       if (conflictingWritingOp != readingOp)
3767a1579acSMatthias Springer         if (auto bufferizableOp =
3777a1579acSMatthias Springer                 options.dynCastBufferizableOp(conflictingWritingOp))
3787a1579acSMatthias Springer           if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
3797a1579acSMatthias Springer             continue;
3807a1579acSMatthias Springer 
3817a1579acSMatthias Springer       // Ops are not conflicting if they are in mutually exclusive regions.
3827a1579acSMatthias Springer       if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
3837a1579acSMatthias Springer         continue;
3847a1579acSMatthias Springer 
3857a1579acSMatthias Springer       // Check all possible last writes.
3867a1579acSMatthias Springer       for (Value lastWrite : lastWrites) {
3877a1579acSMatthias Springer         // No conflict if the conflicting write happens before the last
3887a1579acSMatthias Springer         // write.
3897a1579acSMatthias Springer         if (Operation *writingOp = lastWrite.getDefiningOp()) {
3907a1579acSMatthias Springer           if (happensBefore(conflictingWritingOp, writingOp, domInfo))
3917a1579acSMatthias Springer             // conflictingWritingOp happens before writingOp. No conflict.
3927a1579acSMatthias Springer             continue;
3937a1579acSMatthias Springer           // No conflict if conflictingWritingOp is contained in writingOp.
3947a1579acSMatthias Springer           if (writingOp->isProperAncestor(conflictingWritingOp))
3957a1579acSMatthias Springer             continue;
3967a1579acSMatthias Springer         } else {
3977a1579acSMatthias Springer           auto bbArg = lastWrite.cast<BlockArgument>();
3987a1579acSMatthias Springer           Block *block = bbArg.getOwner();
3997a1579acSMatthias Springer           if (!block->findAncestorOpInBlock(*conflictingWritingOp))
4007a1579acSMatthias Springer             // conflictingWritingOp happens outside of the block. No
4017a1579acSMatthias Springer             // conflict.
4027a1579acSMatthias Springer             continue;
4037a1579acSMatthias Springer         }
4047a1579acSMatthias Springer 
4057a1579acSMatthias Springer         // No conflict if the conflicting write and the last write are the same
4067a1579acSMatthias Springer         // use.
407585a8a32SMatthias Springer         SmallVector<OpResult> aliasingOpResult =
408585a8a32SMatthias Springer             state.getAliasingOpResult(*uConflictingWrite);
409585a8a32SMatthias Springer         if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
4107a1579acSMatthias Springer           continue;
4117a1579acSMatthias Springer 
4127a1579acSMatthias Springer         // All requirements are met. Conflict found!
4137a1579acSMatthias Springer 
4147a1579acSMatthias Springer         if (options.printConflicts)
4157a1579acSMatthias Springer           annotateConflict(uRead, uConflictingWrite, lastWrite);
4167a1579acSMatthias Springer 
4177a1579acSMatthias Springer         return true;
4187a1579acSMatthias Springer       }
4197a1579acSMatthias Springer     }
4207a1579acSMatthias Springer   }
4217a1579acSMatthias Springer 
4227a1579acSMatthias Springer   return false;
4237a1579acSMatthias Springer }
4247a1579acSMatthias Springer 
4257a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read
4267a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization
4277a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that
4287a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains.
4297a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace
4307a1579acSMatthias Springer /// bufferization decision.
4317a1579acSMatthias Springer ///
4327a1579acSMatthias Springer /// Example:
4337a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
4347a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
4357a1579acSMatthias Springer /// %e = tensor.extract_slice %1
4367a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
4377a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
4387a1579acSMatthias Springer ///
4397a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to
4407a1579acSMatthias Springer /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
4417a1579acSMatthias Springer /// conflict because:
4427a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1.
4437a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second
4447a1579acSMatthias Springer ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
4457a1579acSMatthias Springer ///   would no longer be reading the result of %1.
4467a1579acSMatthias Springer ///
4477a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a
4487a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would
4497a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions.
4507a1579acSMatthias Springer ///
4517a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null
4527a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions
4537a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked.
4547a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference(
4557a1579acSMatthias Springer     OpOperand &operand, const DominanceInfo &domInfo, BufferizationState &state,
4567a1579acSMatthias Springer     const BufferizationAliasInfo &aliasInfo,
4577a1579acSMatthias Springer     bool checkConsistencyOnly = false) {
4587a1579acSMatthias Springer   // Helper function to iterate on aliases of `root` and capture the reads.
4597a1579acSMatthias Springer   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
4607a1579acSMatthias Springer     aliasInfo.applyOnAliases(root, [&](Value alias) {
4617a1579acSMatthias Springer       for (auto &use : alias.getUses())
4627a1579acSMatthias Springer         // Read to a value that aliases root.
4637a1579acSMatthias Springer         if (state.bufferizesToMemoryRead(use))
4647a1579acSMatthias Springer           res.insert(&use);
4657a1579acSMatthias Springer     });
4667a1579acSMatthias Springer   };
4677a1579acSMatthias Springer 
4687a1579acSMatthias Springer   // Helper function to iterate on aliases of `root` and capture the writes.
4697a1579acSMatthias Springer   auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
4707a1579acSMatthias Springer     aliasInfo.applyOnAliases(root, [&](Value alias) {
4717a1579acSMatthias Springer       for (auto &use : alias.getUses())
4727a1579acSMatthias Springer         // Inplace write to a value that aliases root.
4737a1579acSMatthias Springer         if (isInplaceMemoryWrite(use, aliasInfo, state))
4747a1579acSMatthias Springer           res.insert(&use);
4757a1579acSMatthias Springer     });
4767a1579acSMatthias Springer   };
4777a1579acSMatthias Springer 
4787a1579acSMatthias Springer   // Collect reads and writes of all aliases of OpOperand and OpResult.
4797a1579acSMatthias Springer   DenseSet<OpOperand *> usesRead, usesWrite;
4807a1579acSMatthias Springer   getAliasingReads(usesRead, operand.get());
4817a1579acSMatthias Springer   getAliasingInplaceWrites(usesWrite, operand.get());
482585a8a32SMatthias Springer   for (OpResult result : state.getAliasingOpResult(operand)) {
4837a1579acSMatthias Springer     getAliasingReads(usesRead, result);
4847a1579acSMatthias Springer     getAliasingInplaceWrites(usesWrite, result);
4857a1579acSMatthias Springer   }
4867a1579acSMatthias Springer   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
4877a1579acSMatthias Springer     usesWrite.insert(&operand);
4887a1579acSMatthias Springer 
4897a1579acSMatthias Springer   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
4907a1579acSMatthias Springer                                        aliasInfo);
4917a1579acSMatthias Springer }
4927a1579acSMatthias Springer 
4937a1579acSMatthias Springer /// Return true if bufferizing `opOperand` inplace would create a write to a
4947a1579acSMatthias Springer /// non-writable buffer.
4957a1579acSMatthias Springer static bool
4967a1579acSMatthias Springer wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
4977a1579acSMatthias Springer                                     const BufferizationAliasInfo &aliasInfo,
4987a1579acSMatthias Springer                                     BufferizationState &state) {
4997a1579acSMatthias Springer   // Certain buffers are not writeable:
5007a1579acSMatthias Springer   //   1. A function bbArg that is not inplaceable or
5017a1579acSMatthias Springer   //   2. A constant op.
5027a1579acSMatthias Springer   bool nonWritable =
5037a1579acSMatthias Springer       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
5047a1579acSMatthias Springer   if (!nonWritable)
5057a1579acSMatthias Springer     return false;
5067a1579acSMatthias Springer 
5077a1579acSMatthias Springer   // This is a problem only if the buffer is written to via some alias.
5087a1579acSMatthias Springer   bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
5097a1579acSMatthias Springer                   state.bufferizesToMemoryWrite(opOperand);
5107a1579acSMatthias Springer 
511585a8a32SMatthias Springer   for (OpResult opResult : state.getAliasingOpResult(opOperand))
5127a1579acSMatthias Springer     hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
5137a1579acSMatthias Springer 
5147a1579acSMatthias Springer   return hasWrite;
5157a1579acSMatthias Springer }
5167a1579acSMatthias Springer 
5177a1579acSMatthias Springer //===----------------------------------------------------------------------===//
5187a1579acSMatthias Springer // Bufferization analyses.
5197a1579acSMatthias Springer //===----------------------------------------------------------------------===//
5207a1579acSMatthias Springer 
5217a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place.
5227a1579acSMatthias Springer static LogicalResult bufferizableInPlaceAnalysisImpl(
5237a1579acSMatthias Springer     OpOperand &operand, BufferizationAliasInfo &aliasInfo,
5247a1579acSMatthias Springer     BufferizationState &state, const DominanceInfo &domInfo) {
5257a1579acSMatthias Springer   bool foundInterference =
5267a1579acSMatthias Springer       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
5277a1579acSMatthias Springer       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
5287a1579acSMatthias Springer 
5297a1579acSMatthias Springer   if (foundInterference)
5307a1579acSMatthias Springer     aliasInfo.bufferizeOutOfPlace(operand);
5317a1579acSMatthias Springer   else
5327a1579acSMatthias Springer     aliasInfo.bufferizeInPlace(operand, state);
5337a1579acSMatthias Springer 
5347a1579acSMatthias Springer   return success();
5357a1579acSMatthias Springer }
5367a1579acSMatthias Springer 
5377a1579acSMatthias Springer /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
5387a1579acSMatthias Springer /// reverse and bufferize ops greedily. This is a good starter heuristic.
5397a1579acSMatthias Springer ///
5407a1579acSMatthias Springer /// Even if an op does not read or write, it may still create an alias when
5417a1579acSMatthias Springer /// bufferized in-place. An example of such ops is tensor.extract_slice.
5427a1579acSMatthias Springer ///
5437a1579acSMatthias Springer /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
5447a1579acSMatthias Springer ///
5457a1579acSMatthias Springer /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
5467a1579acSMatthias Springer /// cannot change the flow of information for either the source or the
5477a1579acSMatthias Springer /// result buffers.
5487a1579acSMatthias Springer ///
5497a1579acSMatthias Springer /// When bufferized inplace, an ExtractSliceOp does not by itself create any
5507a1579acSMatthias Springer /// read or write from memory. Instead, it has the effect of merging the alias
5517a1579acSMatthias Springer /// sets of the source and the result buffers.
5527a1579acSMatthias Springer ///
5537a1579acSMatthias Springer /// An analysis is required to ensure inplace bufferization would not result in
5547a1579acSMatthias Springer /// RaW dependence violations.
5557a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
5567a1579acSMatthias Springer                                      BufferizationAliasInfo &aliasInfo,
5577a1579acSMatthias Springer                                      BufferizationState &state,
5587a1579acSMatthias Springer                                      const DominanceInfo &domInfo,
5597a1579acSMatthias Springer                                      unsigned analysisFuzzerSeed = 0) {
5607a1579acSMatthias Springer   if (analysisFuzzerSeed) {
5617a1579acSMatthias Springer     // This is a fuzzer. For testing purposes only. Randomize the order in which
5627a1579acSMatthias Springer     // operations are analyzed. The bufferization quality is likely worse, but
5637a1579acSMatthias Springer     // we want to make sure that no assertions are triggered anywhere.
5647a1579acSMatthias Springer     std::mt19937 g(analysisFuzzerSeed);
5657a1579acSMatthias Springer     llvm::shuffle(ops.begin(), ops.end(), g);
5667a1579acSMatthias Springer   }
5677a1579acSMatthias Springer 
5687a1579acSMatthias Springer   // Walk ops in reverse for better interference analysis.
5697a1579acSMatthias Springer   for (Operation *op : reverse(ops))
5707a1579acSMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
5717a1579acSMatthias Springer       if (opOperand.get().getType().isa<TensorType>())
5727a1579acSMatthias Springer         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
5737a1579acSMatthias Springer           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
5747a1579acSMatthias Springer                                                      state, domInfo)))
5757a1579acSMatthias Springer             return failure();
5767a1579acSMatthias Springer 
5777a1579acSMatthias Springer   return success();
5787a1579acSMatthias Springer }
5797a1579acSMatthias Springer 
5807a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand.
5817a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) {
5827a1579acSMatthias Springer   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
5837a1579acSMatthias Springer   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
5847a1579acSMatthias Springer   return hasTensorResult || hasTensorOperand;
5857a1579acSMatthias Springer }
5867a1579acSMatthias Springer 
5877a1579acSMatthias Springer /// Analyze all ops that are contained in `op`.
5887a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(Operation *op,
5897a1579acSMatthias Springer                                      BufferizationAliasInfo &aliasInfo,
5907a1579acSMatthias Springer                                      BufferizationState &state,
5917a1579acSMatthias Springer                                      const DominanceInfo &domInfo,
5927a1579acSMatthias Springer                                      unsigned analysisFuzzerSeed = 0) {
5937a1579acSMatthias Springer   // Collect ops so we can build our own reverse traversal.
5947a1579acSMatthias Springer   SmallVector<Operation *> ops;
5957a1579acSMatthias Springer   op->walk([&](Operation *op) {
5967a1579acSMatthias Springer     // No tensors => no buffers.
5977a1579acSMatthias Springer     if (!hasTensorSemantics(op))
5987a1579acSMatthias Springer       return;
5997a1579acSMatthias Springer     ops.push_back(op);
6007a1579acSMatthias Springer   });
6017a1579acSMatthias Springer 
6027a1579acSMatthias Springer   return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
6037a1579acSMatthias Springer }
6047a1579acSMatthias Springer 
6057a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
6067a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops,
6077a1579acSMatthias Springer                                 BufferizationAliasInfo &aliasInfo,
6087a1579acSMatthias Springer                                 BufferizationState &state) {
6097a1579acSMatthias Springer   for (Operation *op : ops)
6107a1579acSMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
6117a1579acSMatthias Springer       for (OpResult opResult : op->getOpResults())
6127a1579acSMatthias Springer         if (opResult.getType().isa<TensorType>())
6137a1579acSMatthias Springer           for (OpOperand *opOperand :
6147a1579acSMatthias Springer                bufferizableOp.getAliasingOpOperand(opResult, state))
6157a1579acSMatthias Springer             if (state.isInPlace(*opOperand))
6167a1579acSMatthias Springer               if (bufferizableOp.bufferRelation(opResult, state) ==
6177a1579acSMatthias Springer                   BufferRelation::Equivalent)
6187a1579acSMatthias Springer                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
6197a1579acSMatthias Springer }
6207a1579acSMatthias Springer 
6217a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
6227a1579acSMatthias Springer /// in `op`.
6237a1579acSMatthias Springer static void equivalenceAnalysis(Operation *op,
6247a1579acSMatthias Springer                                 BufferizationAliasInfo &aliasInfo,
6257a1579acSMatthias Springer                                 BufferizationState &state) {
6267a1579acSMatthias Springer   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
6277a1579acSMatthias Springer   SmallVector<Operation *> ops;
6287a1579acSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
6297a1579acSMatthias Springer     // No tensors => no buffers.
6307a1579acSMatthias Springer     if (none_of(op->getResultTypes(), isaTensor))
6317a1579acSMatthias Springer       return;
6327a1579acSMatthias Springer     ops.push_back(op);
6337a1579acSMatthias Springer   });
6347a1579acSMatthias Springer 
6357a1579acSMatthias Springer   equivalenceAnalysis(ops, aliasInfo, state);
6367a1579acSMatthias Springer }
6377a1579acSMatthias Springer 
6387a1579acSMatthias Springer /// Assert that the current bufferization decisions are consistent.
6397a1579acSMatthias Springer static LogicalResult
6407a1579acSMatthias Springer checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
6417a1579acSMatthias Springer                           BufferizationState &state,
6427a1579acSMatthias Springer                           const BufferizationAliasInfo &aliasInfo) {
6437a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
6447a1579acSMatthias Springer   Operation *inconsistentOp = nullptr;
6457a1579acSMatthias Springer   WalkResult walkResult = op->walk([&](Operation *op) {
6467a1579acSMatthias Springer     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
6477a1579acSMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
6487a1579acSMatthias Springer         if (opOperand.get().getType().isa<TensorType>()) {
6497a1579acSMatthias Springer           if (wouldCreateReadAfterWriteInterference(
6507a1579acSMatthias Springer                   opOperand, domInfo, state, aliasInfo,
6517a1579acSMatthias Springer                   /*checkConsistencyOnly=*/true)) {
6527a1579acSMatthias Springer             // This error can happen if certain "mustBufferizeInPlace" interface
6537a1579acSMatthias Springer             // methods are implemented incorrectly, such that the IR already has
6547a1579acSMatthias Springer             // a RaW conflict before making any bufferization decisions.
6557a1579acSMatthias Springer             inconsistentOp = op;
6567a1579acSMatthias Springer             return WalkResult::interrupt();
6577a1579acSMatthias Springer           }
6587a1579acSMatthias Springer         }
6597a1579acSMatthias Springer     return WalkResult::advance();
6607a1579acSMatthias Springer   });
6617a1579acSMatthias Springer 
6627a1579acSMatthias Springer   if (walkResult.wasInterrupted())
6637a1579acSMatthias Springer     return inconsistentOp->emitError("input IR has RaW conflict");
6647a1579acSMatthias Springer   return success();
6657a1579acSMatthias Springer }
6667a1579acSMatthias Springer 
6677a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only.
6687a1579acSMatthias Springer static void
6697a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op,
6707a1579acSMatthias Springer                                     const BufferizationAliasInfo &aliasInfo,
6717a1579acSMatthias Springer                                     BufferizationState &state) {
6727a1579acSMatthias Springer   op->walk([&](Operation *op) {
6737a1579acSMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
6747a1579acSMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
6757a1579acSMatthias Springer         if (opOperand.get().getType().isa<TensorType>())
6767a1579acSMatthias Springer           setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
6777a1579acSMatthias Springer   });
6787a1579acSMatthias Springer }
6797a1579acSMatthias Springer 
6807a1579acSMatthias Springer /// Assert that IR is in destination-passing style. I.e., every value that is
6817a1579acSMatthias Springer /// returned or yielded from a block is:
6827a1579acSMatthias Springer /// * aliasing a bbArg of that block or a parent block, or
6837a1579acSMatthias Springer /// * aliasing an OpResult of a op in a parent block.
6847a1579acSMatthias Springer ///
6857a1579acSMatthias Springer /// Example:
6867a1579acSMatthias Springer /// ```
6877a1579acSMatthias Springer /// %0 = "some_op" : tensor<?xf32>
6887a1579acSMatthias Springer /// %1 = scf.if %c -> (tensor<?xf32>) {
6897a1579acSMatthias Springer ///   scf.yield %0 : tensor<?xf32>
6907a1579acSMatthias Springer /// } else {
6917a1579acSMatthias Springer ///   %t = linalg.init_tensor : tensor<?xf32>
6927a1579acSMatthias Springer ///   scf.yield %t : tensor<?xf32>
6937a1579acSMatthias Springer /// }
6947a1579acSMatthias Springer /// ```
6957a1579acSMatthias Springer /// In the above example, the first scf.yield op satifies destination-passing
6967a1579acSMatthias Springer /// style because the yielded value %0 is defined in the parent block. The
6977a1579acSMatthias Springer /// second scf.yield op does not satisfy destination-passing style because the
6987a1579acSMatthias Springer /// yielded value %t is defined in the same block as the scf.yield op.
6997a1579acSMatthias Springer // TODO: The current implementation checks for equivalent values instead of
7007a1579acSMatthias Springer // aliasing values, which is stricter than needed. We can currently not check
7017a1579acSMatthias Springer // for aliasing values because the analysis is a maybe-alias analysis and we
7027a1579acSMatthias Springer // need a must-alias analysis here.
703cdb7675cSMatthias Springer static LogicalResult
704cdb7675cSMatthias Springer assertDestinationPassingStyle(Operation *op, BufferizationState &state,
7057a1579acSMatthias Springer                               BufferizationAliasInfo &aliasInfo,
706cdb7675cSMatthias Springer                               SmallVector<Operation *> &newOps) {
7077a1579acSMatthias Springer   LogicalResult status = success();
7087a1579acSMatthias Springer   DominanceInfo domInfo(op);
7097a1579acSMatthias Springer   op->walk([&](Operation *returnOp) {
7107a1579acSMatthias Springer     if (!isRegionReturnLike(returnOp))
7117a1579acSMatthias Springer       return WalkResult::advance();
7127a1579acSMatthias Springer 
7137a1579acSMatthias Springer     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
7147a1579acSMatthias Springer       Value returnVal = returnValOperand.get();
7157a1579acSMatthias Springer       // Skip non-tensor values.
7167a1579acSMatthias Springer       if (!returnVal.getType().isa<TensorType>())
7177a1579acSMatthias Springer         continue;
7187a1579acSMatthias Springer 
7197a1579acSMatthias Springer       bool foundEquivValue = false;
7207a1579acSMatthias Springer       aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
7217a1579acSMatthias Springer         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
7227a1579acSMatthias Springer           Operation *definingOp = bbArg.getOwner()->getParentOp();
7237a1579acSMatthias Springer           if (definingOp->isProperAncestor(returnOp))
7247a1579acSMatthias Springer             foundEquivValue = true;
7257a1579acSMatthias Springer           return;
7267a1579acSMatthias Springer         }
7277a1579acSMatthias Springer 
7287a1579acSMatthias Springer         Operation *definingOp = equivVal.getDefiningOp();
7297a1579acSMatthias Springer         if (definingOp->getBlock()->findAncestorOpInBlock(
7307a1579acSMatthias Springer                 *returnOp->getParentOp()))
7317a1579acSMatthias Springer           // Skip ops that happen after `returnOp` and parent ops.
7327a1579acSMatthias Springer           if (happensBefore(definingOp, returnOp, domInfo))
7337a1579acSMatthias Springer             foundEquivValue = true;
7347a1579acSMatthias Springer       });
7357a1579acSMatthias Springer 
7367a1579acSMatthias Springer       if (!foundEquivValue)
7377a1579acSMatthias Springer         status =
7387a1579acSMatthias Springer             returnOp->emitError()
7397a1579acSMatthias Springer             << "operand #" << returnValOperand.getOperandNumber()
7407a1579acSMatthias Springer             << " of ReturnLike op does not satisfy destination passing style";
7417a1579acSMatthias Springer     }
7427a1579acSMatthias Springer 
7437a1579acSMatthias Springer     return WalkResult::advance();
7447a1579acSMatthias Springer   });
7457a1579acSMatthias Springer 
7467a1579acSMatthias Springer   return status;
7477a1579acSMatthias Springer }
7487a1579acSMatthias Springer 
7497a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op,
7507a1579acSMatthias Springer                                        AnalysisBufferizationState &state) {
7517a1579acSMatthias Springer   DominanceInfo domInfo(op);
7527a1579acSMatthias Springer   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
7537a1579acSMatthias Springer   const auto &options =
7547a1579acSMatthias Springer       static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
7557a1579acSMatthias Springer 
7567a1579acSMatthias Springer   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
7577a1579acSMatthias Springer     return failure();
7587a1579acSMatthias Springer 
7597a1579acSMatthias Springer   // If the analysis fails, just return.
7607a1579acSMatthias Springer   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
7617a1579acSMatthias Springer                              options.analysisFuzzerSeed)))
7627a1579acSMatthias Springer     return failure();
7637a1579acSMatthias Springer   equivalenceAnalysis(op, aliasInfo, state);
7647a1579acSMatthias Springer 
765cdb7675cSMatthias Springer   for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
7667a1579acSMatthias Springer     SmallVector<Operation *> newOps;
767cdb7675cSMatthias Springer     if (failed(fn(op, state, aliasInfo, newOps)))
7687a1579acSMatthias Springer       return failure();
769cdb7675cSMatthias Springer     // Analyze ops that were created by the PostAnalysisStepFn.
7707a1579acSMatthias Springer     if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
7717a1579acSMatthias Springer       return failure();
7727a1579acSMatthias Springer     equivalenceAnalysis(newOps, aliasInfo, state);
7737a1579acSMatthias Springer   }
7747a1579acSMatthias Springer 
7757a1579acSMatthias Springer   if (!options.allowReturnMemref) {
7767a1579acSMatthias Springer     SmallVector<Operation *> newOps;
777cdb7675cSMatthias Springer     if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps)))
7787a1579acSMatthias Springer       return failure();
7797a1579acSMatthias Springer   }
7807a1579acSMatthias Springer 
781*4ec00fb3SMatthias Springer   // Analysis verification: After setting up alias/equivalence sets, each op
782*4ec00fb3SMatthias Springer   // can check for expected invariants/limitations and fail the analysis if
783*4ec00fb3SMatthias Springer   // necessary.
784*4ec00fb3SMatthias Springer   bool passedAnalysis = true;
785*4ec00fb3SMatthias Springer   op->walk([&](Operation *op) {
786*4ec00fb3SMatthias Springer     if (BufferizableOpInterface bufferizableOp =
787*4ec00fb3SMatthias Springer             options.dynCastBufferizableOp(op))
788*4ec00fb3SMatthias Springer       if (failed(bufferizableOp.verifyAnalysis(state)))
789*4ec00fb3SMatthias Springer         passedAnalysis = false;
790*4ec00fb3SMatthias Springer   });
791*4ec00fb3SMatthias Springer   if (!passedAnalysis)
792*4ec00fb3SMatthias Springer     return failure();
793*4ec00fb3SMatthias Springer 
7947a1579acSMatthias Springer   // Annotate operations if we only want to report the analysis.
7957a1579acSMatthias Springer   if (options.testAnalysisOnly)
7967a1579acSMatthias Springer     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);
7977a1579acSMatthias Springer 
7987a1579acSMatthias Springer   return success();
7997a1579acSMatthias Springer }
8007a1579acSMatthias Springer 
8017a1579acSMatthias Springer LogicalResult bufferization::runOneShotBufferize(
8027a1579acSMatthias Springer     Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
8037a1579acSMatthias Springer   AnalysisBufferizationState state(op, *options);
8047a1579acSMatthias Springer   if (failed(analyzeOp(op, state)))
8057a1579acSMatthias Springer     return failure();
8067a1579acSMatthias Springer   if (options->testAnalysisOnly)
8077a1579acSMatthias Springer     return success();
8087a1579acSMatthias Springer   return bufferizeOp(op, state);
8097a1579acSMatthias Springer }
810