//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
// bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
// call graphs.
//
// One-Shot Bufferize consists of two phases.
//
// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
//    inserting buffer copies. The analysis queries op bufferization semantics
//    via `BufferizableOpInterface`.
// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
//    function does not generate buffer copies for OpResults that were decided
//    to bufferize inplace during the analysis phase.
//
// This file contains only the analysis. The actual bufferization is implemented
// via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
// helper function `runOneShotBufferize` that analyzes an op (and its nested
// ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
// They can be printed for debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnMemref` is enabled.
//===----------------------------------------------------------------------===//
417a1579acSMatthias Springer 427a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 437a1579acSMatthias Springer 447a1579acSMatthias Springer #include <random> 457a1579acSMatthias Springer 467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 497a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 507a1579acSMatthias Springer #include "mlir/IR/AsmState.h" 517a1579acSMatthias Springer #include "mlir/IR/Dominance.h" 527a1579acSMatthias Springer #include "mlir/IR/Operation.h" 537a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h" 547a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h" 557a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h" 567a1579acSMatthias Springer #include "llvm/ADT/SetVector.h" 577a1579acSMatthias Springer 587a1579acSMatthias Springer using namespace mlir; 597a1579acSMatthias Springer using namespace mlir::bufferization; 607a1579acSMatthias Springer 617a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); } 627a1579acSMatthias Springer 637a1579acSMatthias Springer //===----------------------------------------------------------------------===// 647a1579acSMatthias Springer // Bufferization-specific attribute manipulation. 657a1579acSMatthias Springer // These are for testing and debugging only. Bufferization information is 667a1579acSMatthias Springer // stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR 677a1579acSMatthias Springer // is annotated with the results of the analysis (copied from 687a1579acSMatthias Springer // BufferizationAliasInfo), so that they can be checked in tests. 
697a1579acSMatthias Springer //===----------------------------------------------------------------------===// 707a1579acSMatthias Springer 717a1579acSMatthias Springer /// Attribute marker to specify op results that can be bufferized inPlace. 727a1579acSMatthias Springer constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__"; 737a1579acSMatthias Springer 747a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace. 757a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { 767a1579acSMatthias Springer Operation *op = opOperand.getOwner(); 777a1579acSMatthias Springer auto attr = 787a1579acSMatthias Springer op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>(); 797a1579acSMatthias Springer SmallVector<StringRef> inPlaceVector; 807a1579acSMatthias Springer if (attr) { 817a1579acSMatthias Springer inPlaceVector = SmallVector<StringRef>( 827a1579acSMatthias Springer llvm::to_vector<4>(attr.getAsValueRange<StringAttr>())); 837a1579acSMatthias Springer } else { 847a1579acSMatthias Springer inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none"); 857a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 867a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) 877a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = "false"; 887a1579acSMatthias Springer } 897a1579acSMatthias Springer 907a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = inPlace ? 
"true" : "false"; 917a1579acSMatthias Springer op->setAttr(kInPlaceResultsAttrName, 927a1579acSMatthias Springer OpBuilder(op).getStrArrayAttr(inPlaceVector)); 937a1579acSMatthias Springer } 947a1579acSMatthias Springer 957a1579acSMatthias Springer //===----------------------------------------------------------------------===// 967a1579acSMatthias Springer // BufferizationAliasInfo 977a1579acSMatthias Springer //===----------------------------------------------------------------------===// 987a1579acSMatthias Springer 997a1579acSMatthias Springer BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { 1007a1579acSMatthias Springer rootOp->walk([&](Operation *op) { 1017a1579acSMatthias Springer for (Value v : op->getResults()) 1027a1579acSMatthias Springer if (v.getType().isa<TensorType>()) 1037a1579acSMatthias Springer createAliasInfoEntry(v); 1047a1579acSMatthias Springer for (Region &r : op->getRegions()) 1057a1579acSMatthias Springer for (Block &b : r.getBlocks()) 1067a1579acSMatthias Springer for (auto bbArg : b.getArguments()) 1077a1579acSMatthias Springer if (bbArg.getType().isa<TensorType>()) 1087a1579acSMatthias Springer createAliasInfoEntry(bbArg); 1097a1579acSMatthias Springer }); 1107a1579acSMatthias Springer } 1117a1579acSMatthias Springer 1127a1579acSMatthias Springer /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the 1137a1579acSMatthias Springer /// beginning the alias and equivalence sets only contain `v` itself. 1147a1579acSMatthias Springer void BufferizationAliasInfo::createAliasInfoEntry(Value v) { 1157a1579acSMatthias Springer aliasInfo.insert(v); 1167a1579acSMatthias Springer equivalentInfo.insert(v); 1177a1579acSMatthias Springer } 1187a1579acSMatthias Springer 1197a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of 1207a1579acSMatthias Springer /// `alias`. 
1217a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { 1227a1579acSMatthias Springer createAliasInfoEntry(newValue); 1237a1579acSMatthias Springer aliasInfo.unionSets(newValue, alias); 1247a1579acSMatthias Springer } 1257a1579acSMatthias Springer 1267a1579acSMatthias Springer /// Insert an info entry for `newValue` and merge its alias set with that of 1277a1579acSMatthias Springer /// `alias`. Additionally, merge their equivalence classes. 1287a1579acSMatthias Springer void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, 1297a1579acSMatthias Springer Value alias) { 1307a1579acSMatthias Springer insertNewBufferAlias(newValue, alias); 1317a1579acSMatthias Springer equivalentInfo.unionSets(newValue, alias); 1327a1579acSMatthias Springer } 1337a1579acSMatthias Springer 1347a1579acSMatthias Springer /// Return `true` if a value was marked as in-place bufferized. 1357a1579acSMatthias Springer bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { 1367a1579acSMatthias Springer return inplaceBufferized.contains(&operand); 1377a1579acSMatthias Springer } 1387a1579acSMatthias Springer 1397a1579acSMatthias Springer /// Set the inPlace bufferization spec to true. 1407a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, 1417a1579acSMatthias Springer BufferizationState &state) { 1427a1579acSMatthias Springer markInPlace(operand); 143585a8a32SMatthias Springer for (OpResult result : state.getAliasingOpResult(operand)) 1447a1579acSMatthias Springer aliasInfo.unionSets(result, operand.get()); 1457a1579acSMatthias Springer } 1467a1579acSMatthias Springer 1477a1579acSMatthias Springer /// Set the inPlace bufferization spec to false. 
1487a1579acSMatthias Springer void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { 1497a1579acSMatthias Springer assert(!inplaceBufferized.contains(&operand) && 1507a1579acSMatthias Springer "OpOperand was already decided to bufferize inplace"); 1517a1579acSMatthias Springer } 1527a1579acSMatthias Springer 1537a1579acSMatthias Springer /// Apply `fun` to all the members of the equivalence class of `v`. 1547a1579acSMatthias Springer void BufferizationAliasInfo::applyOnEquivalenceClass( 1557a1579acSMatthias Springer Value v, function_ref<void(Value)> fun) const { 1567a1579acSMatthias Springer auto leaderIt = equivalentInfo.findLeader(v); 1577a1579acSMatthias Springer for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; 1587a1579acSMatthias Springer ++mit) { 1597a1579acSMatthias Springer fun(*mit); 1607a1579acSMatthias Springer } 1617a1579acSMatthias Springer } 1627a1579acSMatthias Springer 1637a1579acSMatthias Springer /// Apply `fun` to all aliases of `v`. 
1647a1579acSMatthias Springer void BufferizationAliasInfo::applyOnAliases( 1657a1579acSMatthias Springer Value v, function_ref<void(Value)> fun) const { 1667a1579acSMatthias Springer auto leaderIt = aliasInfo.findLeader(v); 1677a1579acSMatthias Springer for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { 1687a1579acSMatthias Springer fun(*mit); 1697a1579acSMatthias Springer } 1707a1579acSMatthias Springer } 1717a1579acSMatthias Springer 1727a1579acSMatthias Springer BufferizationAliasInfo::EquivalenceClassRangeType 1737a1579acSMatthias Springer BufferizationAliasInfo::getAliases(Value v) const { 1747a1579acSMatthias Springer DenseSet<Value> res; 1757a1579acSMatthias Springer auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); 1767a1579acSMatthias Springer for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); 1777a1579acSMatthias Springer mit != meit; ++mit) { 1787a1579acSMatthias Springer res.insert(static_cast<Value>(*mit)); 1797a1579acSMatthias Springer } 1807a1579acSMatthias Springer return BufferizationAliasInfo::EquivalenceClassRangeType( 1817a1579acSMatthias Springer aliasInfo.member_begin(it), aliasInfo.member_end()); 1827a1579acSMatthias Springer } 1837a1579acSMatthias Springer 1847a1579acSMatthias Springer //===----------------------------------------------------------------------===// 1857a1579acSMatthias Springer // AnalysisBufferizationState 1867a1579acSMatthias Springer //===----------------------------------------------------------------------===// 1877a1579acSMatthias Springer 1887a1579acSMatthias Springer AnalysisBufferizationState::AnalysisBufferizationState( 1897a1579acSMatthias Springer Operation *op, const AnalysisBufferizationOptions &options) 1907a1579acSMatthias Springer : BufferizationState(options), aliasInfo(op) { 1917a1579acSMatthias Springer // Set up alias sets for OpResults that must bufferize in-place. 
This should 1927a1579acSMatthias Springer // be done before making any other bufferization decisions. 1937a1579acSMatthias Springer op->walk([&](BufferizableOpInterface bufferizableOp) { 1947a1579acSMatthias Springer if (!options.isOpAllowed(bufferizableOp)) 1957a1579acSMatthias Springer return WalkResult::skip(); 1967a1579acSMatthias Springer for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { 1977a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) 1987a1579acSMatthias Springer if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { 199585a8a32SMatthias Springer for (OpResult opResult : 2007a1579acSMatthias Springer bufferizableOp.getAliasingOpResult(opOperand, *this)) 2017a1579acSMatthias Springer aliasInfo.unionAliasSets(opOperand.get(), opResult); 2027a1579acSMatthias Springer aliasInfo.markInPlace(opOperand); 2037a1579acSMatthias Springer } 2047a1579acSMatthias Springer } 2057a1579acSMatthias Springer return WalkResult::advance(); 2067a1579acSMatthias Springer }); 2077a1579acSMatthias Springer } 2087a1579acSMatthias Springer 2097a1579acSMatthias Springer bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const { 2107a1579acSMatthias Springer return aliasInfo.isInPlace(opOperand); 2117a1579acSMatthias Springer } 2127a1579acSMatthias Springer 2137a1579acSMatthias Springer bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1, 2147a1579acSMatthias Springer Value v2) const { 2157a1579acSMatthias Springer return aliasInfo.areEquivalentBufferizedValues(v1, v2); 2167a1579acSMatthias Springer } 2177a1579acSMatthias Springer 2187a1579acSMatthias Springer //===----------------------------------------------------------------------===// 2197a1579acSMatthias Springer // Bufferization-specific alias analysis. 
2207a1579acSMatthias Springer //===----------------------------------------------------------------------===// 2217a1579acSMatthias Springer 2227a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place. 2237a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand, 2247a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 2257a1579acSMatthias Springer BufferizationState &state) { 2267a1579acSMatthias Springer // OpOperands that do not bufferize to a memory write do not write in-place. 2277a1579acSMatthias Springer if (!state.bufferizesToMemoryWrite(opOperand)) 2287a1579acSMatthias Springer return false; 2297a1579acSMatthias Springer // Check current bufferization decisions. 2307a1579acSMatthias Springer return aliasInfo.isInPlace(opOperand); 2317a1579acSMatthias Springer } 2327a1579acSMatthias Springer 2337a1579acSMatthias Springer /// Return true if, under current bufferization decisions, the buffer of `value` 2347a1579acSMatthias Springer /// is not writable. 2357a1579acSMatthias Springer static bool aliasesNonWritableBuffer(Value value, 2367a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 2377a1579acSMatthias Springer BufferizationState &state) { 2387a1579acSMatthias Springer bool foundNonWritableBuffer = false; 2397a1579acSMatthias Springer aliasInfo.applyOnAliases(value, [&](Value v) { 2407a1579acSMatthias Springer // Query BufferizableOpInterface to see if the value is writable. 2417a1579acSMatthias Springer // TODO: Out-of-place bufferized value could be considered writable. 2427a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v)) 2437a1579acSMatthias Springer if (bufferizableOp && bufferizableOp.isWritable(v, state)) 2447a1579acSMatthias Springer return; 2457a1579acSMatthias Springer 2467a1579acSMatthias Springer // Query BufferizableOpInterface to see if the BlockArgument is writable. 
2477a1579acSMatthias Springer if (auto bbArg = v.dyn_cast<BlockArgument>()) 2487a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp( 2497a1579acSMatthias Springer bbArg.getOwner()->getParentOp())) 2507a1579acSMatthias Springer if (bufferizableOp.isWritable(bbArg, state)) 2517a1579acSMatthias Springer return; 2527a1579acSMatthias Springer 2537a1579acSMatthias Springer foundNonWritableBuffer = true; 2547a1579acSMatthias Springer }); 2557a1579acSMatthias Springer 2567a1579acSMatthias Springer return foundNonWritableBuffer; 2577a1579acSMatthias Springer } 2587a1579acSMatthias Springer 2597a1579acSMatthias Springer /// Return true if the buffer to which `operand` would bufferize is equivalent 2607a1579acSMatthias Springer /// to some buffer write. 2617a1579acSMatthias Springer static bool aliasesInPlaceWrite(Value value, 2627a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 2637a1579acSMatthias Springer BufferizationState &state) { 2647a1579acSMatthias Springer bool foundInplaceWrite = false; 2657a1579acSMatthias Springer aliasInfo.applyOnAliases(value, [&](Value v) { 2667a1579acSMatthias Springer for (auto &use : v.getUses()) { 2677a1579acSMatthias Springer if (isInplaceMemoryWrite(use, aliasInfo, state)) { 2687a1579acSMatthias Springer foundInplaceWrite = true; 2697a1579acSMatthias Springer return; 2707a1579acSMatthias Springer } 2717a1579acSMatthias Springer } 2727a1579acSMatthias Springer }); 2737a1579acSMatthias Springer return foundInplaceWrite; 2747a1579acSMatthias Springer } 2757a1579acSMatthias Springer 2767a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 2777a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`. 
2787a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b, 2797a1579acSMatthias Springer const DominanceInfo &domInfo) { 2807a1579acSMatthias Springer do { 2817a1579acSMatthias Springer // TODO: Instead of isProperAncestor + properlyDominates, we should use 2827a1579acSMatthias Springer // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false) 2837a1579acSMatthias Springer if (a->isProperAncestor(b)) 2847a1579acSMatthias Springer return false; 2857a1579acSMatthias Springer if (domInfo.properlyDominates(a, b)) 2867a1579acSMatthias Springer return true; 2877a1579acSMatthias Springer } while ((a = a->getParentOp())); 2887a1579acSMatthias Springer return false; 2897a1579acSMatthias Springer } 2907a1579acSMatthias Springer 2917a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict. 2927a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, 2937a1579acSMatthias Springer Value lastWrite) { 2947a1579acSMatthias Springer static uint64_t counter = 0; 2957a1579acSMatthias Springer Operation *readingOp = uRead->getOwner(); 2967a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 2977a1579acSMatthias Springer 2987a1579acSMatthias Springer OpBuilder b(conflictingWritingOp->getContext()); 2997a1579acSMatthias Springer std::string id = "C_" + std::to_string(counter++); 3007a1579acSMatthias Springer 3017a1579acSMatthias Springer std::string conflictingWriteAttr = 3027a1579acSMatthias Springer id + 3037a1579acSMatthias Springer "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) + 3047a1579acSMatthias Springer "]"; 3057a1579acSMatthias Springer conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr()); 3067a1579acSMatthias Springer 3077a1579acSMatthias Springer std::string readAttr = 3087a1579acSMatthias Springer id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; 3097a1579acSMatthias Springer 
readingOp->setAttr(readAttr, b.getUnitAttr()); 3107a1579acSMatthias Springer 3117a1579acSMatthias Springer if (auto opResult = lastWrite.dyn_cast<OpResult>()) { 3127a1579acSMatthias Springer std::string lastWriteAttr = id + "[LAST-WRITE: result " + 3137a1579acSMatthias Springer std::to_string(opResult.getResultNumber()) + 3147a1579acSMatthias Springer "]"; 3157a1579acSMatthias Springer opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr()); 3167a1579acSMatthias Springer } else { 3177a1579acSMatthias Springer auto bbArg = lastWrite.cast<BlockArgument>(); 3187a1579acSMatthias Springer std::string lastWriteAttr = 3197a1579acSMatthias Springer id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; 3207a1579acSMatthias Springer bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr()); 3217a1579acSMatthias Springer } 3227a1579acSMatthias Springer } 3237a1579acSMatthias Springer 3247a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under 3257a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that 3267a1579acSMatthias Springer /// all given writes bufferize inplace. 3277a1579acSMatthias Springer /// 3287a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read 3297a1579acSMatthias Springer /// the result of a write W1. But because of bufferization decisions, R actually 3307a1579acSMatthias Springer /// reads another write W2. 
3317a1579acSMatthias Springer static bool hasReadAfterWriteInterference( 3327a1579acSMatthias Springer const DenseSet<OpOperand *> &usesRead, 3337a1579acSMatthias Springer const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo, 3347a1579acSMatthias Springer BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { 3357a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions(); 3367a1579acSMatthias Springer 3377a1579acSMatthias Springer for (OpOperand *uRead : usesRead) { 3387a1579acSMatthias Springer Operation *readingOp = uRead->getOwner(); 3397a1579acSMatthias Springer 3407a1579acSMatthias Springer // Find most recent writes of uRead by following the SSA use-def chain. 3417a1579acSMatthias Springer // E.g.: 3427a1579acSMatthias Springer // 3437a1579acSMatthias Springer // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32> 3447a1579acSMatthias Springer // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32> 3457a1579acSMatthias Springer // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type 3467a1579acSMatthias Springer // 3477a1579acSMatthias Springer // In the above example, if uRead is the OpOperand of reading_op, lastWrite 3487a1579acSMatthias Springer // is %0. Note that operations that create an alias but do not write (such 3497a1579acSMatthias Springer // as ExtractSliceOp) are skipped. 3507a1579acSMatthias Springer SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get()); 3517a1579acSMatthias Springer 3527a1579acSMatthias Springer // Look for conflicting memory writes. Potential conflicts are writes to an 3537a1579acSMatthias Springer // alias that have been decided to bufferize inplace. 3547a1579acSMatthias Springer for (OpOperand *uConflictingWrite : usesWrite) { 3557a1579acSMatthias Springer // Throughout this loop, check for multiple requirements that have to be 3567a1579acSMatthias Springer // met for uConflictingWrite to be an actual conflict. 
3577a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 3587a1579acSMatthias Springer 3597a1579acSMatthias Springer // No conflict if the readingOp dominates conflictingWritingOp, i.e., the 3607a1579acSMatthias Springer // write is not visible when reading. 3617a1579acSMatthias Springer if (happensBefore(readingOp, conflictingWritingOp, domInfo)) 3627a1579acSMatthias Springer continue; 3637a1579acSMatthias Springer 3647a1579acSMatthias Springer // No conflict if the reading use equals the use of the conflicting write. 3657a1579acSMatthias Springer // A use cannot conflict with itself. Note: Just being the same op is not 3667a1579acSMatthias Springer // enough. It has to be the same use. 3677a1579acSMatthias Springer if (uConflictingWrite == uRead) 3687a1579acSMatthias Springer continue; 3697a1579acSMatthias Springer 3707a1579acSMatthias Springer // No conflict if the op interface says so. 3717a1579acSMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) 3727a1579acSMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) 3737a1579acSMatthias Springer continue; 3747a1579acSMatthias Springer 3757a1579acSMatthias Springer if (conflictingWritingOp != readingOp) 3767a1579acSMatthias Springer if (auto bufferizableOp = 3777a1579acSMatthias Springer options.dynCastBufferizableOp(conflictingWritingOp)) 3787a1579acSMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) 3797a1579acSMatthias Springer continue; 3807a1579acSMatthias Springer 3817a1579acSMatthias Springer // Ops are not conflicting if they are in mutually exclusive regions. 3827a1579acSMatthias Springer if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) 3837a1579acSMatthias Springer continue; 3847a1579acSMatthias Springer 3857a1579acSMatthias Springer // Check all possible last writes. 
3867a1579acSMatthias Springer for (Value lastWrite : lastWrites) { 3877a1579acSMatthias Springer // No conflict if the conflicting write happens before the last 3887a1579acSMatthias Springer // write. 3897a1579acSMatthias Springer if (Operation *writingOp = lastWrite.getDefiningOp()) { 3907a1579acSMatthias Springer if (happensBefore(conflictingWritingOp, writingOp, domInfo)) 3917a1579acSMatthias Springer // conflictingWritingOp happens before writingOp. No conflict. 3927a1579acSMatthias Springer continue; 3937a1579acSMatthias Springer // No conflict if conflictingWritingOp is contained in writingOp. 3947a1579acSMatthias Springer if (writingOp->isProperAncestor(conflictingWritingOp)) 3957a1579acSMatthias Springer continue; 3967a1579acSMatthias Springer } else { 3977a1579acSMatthias Springer auto bbArg = lastWrite.cast<BlockArgument>(); 3987a1579acSMatthias Springer Block *block = bbArg.getOwner(); 3997a1579acSMatthias Springer if (!block->findAncestorOpInBlock(*conflictingWritingOp)) 4007a1579acSMatthias Springer // conflictingWritingOp happens outside of the block. No 4017a1579acSMatthias Springer // conflict. 4027a1579acSMatthias Springer continue; 4037a1579acSMatthias Springer } 4047a1579acSMatthias Springer 4057a1579acSMatthias Springer // No conflict if the conflicting write and the last write are the same 4067a1579acSMatthias Springer // use. 407585a8a32SMatthias Springer SmallVector<OpResult> aliasingOpResult = 408585a8a32SMatthias Springer state.getAliasingOpResult(*uConflictingWrite); 409585a8a32SMatthias Springer if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) 4107a1579acSMatthias Springer continue; 4117a1579acSMatthias Springer 4127a1579acSMatthias Springer // All requirements are met. Conflict found! 
4137a1579acSMatthias Springer 4147a1579acSMatthias Springer if (options.printConflicts) 4157a1579acSMatthias Springer annotateConflict(uRead, uConflictingWrite, lastWrite); 4167a1579acSMatthias Springer 4177a1579acSMatthias Springer return true; 4187a1579acSMatthias Springer } 4197a1579acSMatthias Springer } 4207a1579acSMatthias Springer } 4217a1579acSMatthias Springer 4227a1579acSMatthias Springer return false; 4237a1579acSMatthias Springer } 4247a1579acSMatthias Springer 4257a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read 4267a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization 4277a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that 4287a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains. 4297a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace 4307a1579acSMatthias Springer /// bufferization decision. 4317a1579acSMatthias Springer /// 4327a1579acSMatthias Springer /// Example: 4337a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?} 4347a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32> 4357a1579acSMatthias Springer /// %e = tensor.extract_slice %1 4367a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32> 4377a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32> 4387a1579acSMatthias Springer /// 4397a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to 4407a1579acSMatthias Springer /// bufferize inplace. 
Bufferizing the ExtractSliceOp inplace would create a 4417a1579acSMatthias Springer /// conflict because: 4427a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1. 4437a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second 4447a1579acSMatthias Springer /// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp 4457a1579acSMatthias Springer /// would no longer be reading the result of %1. 4467a1579acSMatthias Springer /// 4477a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a 4487a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would 4497a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions. 4507a1579acSMatthias Springer /// 4517a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null 4527a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions 4537a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked. 4547a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference( 4557a1579acSMatthias Springer OpOperand &operand, const DominanceInfo &domInfo, BufferizationState &state, 4567a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 4577a1579acSMatthias Springer bool checkConsistencyOnly = false) { 4587a1579acSMatthias Springer // Helper function to iterate on aliases of `root` and capture the reads. 4597a1579acSMatthias Springer auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) { 4607a1579acSMatthias Springer aliasInfo.applyOnAliases(root, [&](Value alias) { 4617a1579acSMatthias Springer for (auto &use : alias.getUses()) 4627a1579acSMatthias Springer // Read to a value that aliases root. 
4637a1579acSMatthias Springer if (state.bufferizesToMemoryRead(use)) 4647a1579acSMatthias Springer res.insert(&use); 4657a1579acSMatthias Springer }); 4667a1579acSMatthias Springer }; 4677a1579acSMatthias Springer 4687a1579acSMatthias Springer // Helper function to iterate on aliases of `root` and capture the writes. 4697a1579acSMatthias Springer auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) { 4707a1579acSMatthias Springer aliasInfo.applyOnAliases(root, [&](Value alias) { 4717a1579acSMatthias Springer for (auto &use : alias.getUses()) 4727a1579acSMatthias Springer // Inplace write to a value that aliases root. 4737a1579acSMatthias Springer if (isInplaceMemoryWrite(use, aliasInfo, state)) 4747a1579acSMatthias Springer res.insert(&use); 4757a1579acSMatthias Springer }); 4767a1579acSMatthias Springer }; 4777a1579acSMatthias Springer 4787a1579acSMatthias Springer // Collect reads and writes of all aliases of OpOperand and OpResult. 4797a1579acSMatthias Springer DenseSet<OpOperand *> usesRead, usesWrite; 4807a1579acSMatthias Springer getAliasingReads(usesRead, operand.get()); 4817a1579acSMatthias Springer getAliasingInplaceWrites(usesWrite, operand.get()); 482585a8a32SMatthias Springer for (OpResult result : state.getAliasingOpResult(operand)) { 4837a1579acSMatthias Springer getAliasingReads(usesRead, result); 4847a1579acSMatthias Springer getAliasingInplaceWrites(usesWrite, result); 4857a1579acSMatthias Springer } 4867a1579acSMatthias Springer if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) 4877a1579acSMatthias Springer usesWrite.insert(&operand); 4887a1579acSMatthias Springer 4897a1579acSMatthias Springer return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, 4907a1579acSMatthias Springer aliasInfo); 4917a1579acSMatthias Springer } 4927a1579acSMatthias Springer 4937a1579acSMatthias Springer /// Return true if bufferizing `opOperand` inplace would create a write to a 4947a1579acSMatthias Springer /// 
non-writable buffer. 4957a1579acSMatthias Springer static bool 4967a1579acSMatthias Springer wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, 4977a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 4987a1579acSMatthias Springer BufferizationState &state) { 4997a1579acSMatthias Springer // Certain buffers are not writeable: 5007a1579acSMatthias Springer // 1. A function bbArg that is not inplaceable or 5017a1579acSMatthias Springer // 2. A constant op. 5027a1579acSMatthias Springer bool nonWritable = 5037a1579acSMatthias Springer aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state); 5047a1579acSMatthias Springer if (!nonWritable) 5057a1579acSMatthias Springer return false; 5067a1579acSMatthias Springer 5077a1579acSMatthias Springer // This is a problem only if the buffer is written to via some alias. 5087a1579acSMatthias Springer bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) || 5097a1579acSMatthias Springer state.bufferizesToMemoryWrite(opOperand); 5107a1579acSMatthias Springer 511585a8a32SMatthias Springer for (OpResult opResult : state.getAliasingOpResult(opOperand)) 5127a1579acSMatthias Springer hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state); 5137a1579acSMatthias Springer 5147a1579acSMatthias Springer return hasWrite; 5157a1579acSMatthias Springer } 5167a1579acSMatthias Springer 5177a1579acSMatthias Springer //===----------------------------------------------------------------------===// 5187a1579acSMatthias Springer // Bufferization analyses. 5197a1579acSMatthias Springer //===----------------------------------------------------------------------===// 5207a1579acSMatthias Springer 5217a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place. 
5227a1579acSMatthias Springer static LogicalResult bufferizableInPlaceAnalysisImpl( 5237a1579acSMatthias Springer OpOperand &operand, BufferizationAliasInfo &aliasInfo, 5247a1579acSMatthias Springer BufferizationState &state, const DominanceInfo &domInfo) { 5257a1579acSMatthias Springer bool foundInterference = 5267a1579acSMatthias Springer wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) || 5277a1579acSMatthias Springer wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); 5287a1579acSMatthias Springer 5297a1579acSMatthias Springer if (foundInterference) 5307a1579acSMatthias Springer aliasInfo.bufferizeOutOfPlace(operand); 5317a1579acSMatthias Springer else 5327a1579acSMatthias Springer aliasInfo.bufferizeInPlace(operand, state); 5337a1579acSMatthias Springer 5347a1579acSMatthias Springer return success(); 5357a1579acSMatthias Springer } 5367a1579acSMatthias Springer 5377a1579acSMatthias Springer /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in 5387a1579acSMatthias Springer /// reverse and bufferize ops greedily. This is a good starter heuristic. 5397a1579acSMatthias Springer /// 5407a1579acSMatthias Springer /// Even if an op does not read or write, it may still create an alias when 5417a1579acSMatthias Springer /// bufferized in-place. An example of such ops is tensor.extract_slice. 5427a1579acSMatthias Springer /// 5437a1579acSMatthias Springer /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: 5447a1579acSMatthias Springer /// 5457a1579acSMatthias Springer /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This 5467a1579acSMatthias Springer /// cannot change the flow of information for either the source or the 5477a1579acSMatthias Springer /// result buffers. 
5487a1579acSMatthias Springer /// 5497a1579acSMatthias Springer /// When bufferized inplace, an ExtractSliceOp does not by itself create any 5507a1579acSMatthias Springer /// read or write from memory. Instead, it has the effect of merging the alias 5517a1579acSMatthias Springer /// sets of the source and the result buffers. 5527a1579acSMatthias Springer /// 5537a1579acSMatthias Springer /// An analysis is required to ensure inplace bufferization would not result in 5547a1579acSMatthias Springer /// RaW dependence violations. 5557a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops, 5567a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo, 5577a1579acSMatthias Springer BufferizationState &state, 5587a1579acSMatthias Springer const DominanceInfo &domInfo, 5597a1579acSMatthias Springer unsigned analysisFuzzerSeed = 0) { 5607a1579acSMatthias Springer if (analysisFuzzerSeed) { 5617a1579acSMatthias Springer // This is a fuzzer. For testing purposes only. Randomize the order in which 5627a1579acSMatthias Springer // operations are analyzed. The bufferization quality is likely worse, but 5637a1579acSMatthias Springer // we want to make sure that no assertions are triggered anywhere. 5647a1579acSMatthias Springer std::mt19937 g(analysisFuzzerSeed); 5657a1579acSMatthias Springer llvm::shuffle(ops.begin(), ops.end(), g); 5667a1579acSMatthias Springer } 5677a1579acSMatthias Springer 5687a1579acSMatthias Springer // Walk ops in reverse for better interference analysis. 
5697a1579acSMatthias Springer for (Operation *op : reverse(ops)) 5707a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 5717a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) 5727a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 5737a1579acSMatthias Springer if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, 5747a1579acSMatthias Springer state, domInfo))) 5757a1579acSMatthias Springer return failure(); 5767a1579acSMatthias Springer 5777a1579acSMatthias Springer return success(); 5787a1579acSMatthias Springer } 5797a1579acSMatthias Springer 5807a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand. 5817a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) { 5827a1579acSMatthias Springer bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 5837a1579acSMatthias Springer bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 5847a1579acSMatthias Springer return hasTensorResult || hasTensorOperand; 5857a1579acSMatthias Springer } 5867a1579acSMatthias Springer 5877a1579acSMatthias Springer /// Analyze all ops that are contained in `op`. 5887a1579acSMatthias Springer static LogicalResult inPlaceAnalysis(Operation *op, 5897a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo, 5907a1579acSMatthias Springer BufferizationState &state, 5917a1579acSMatthias Springer const DominanceInfo &domInfo, 5927a1579acSMatthias Springer unsigned analysisFuzzerSeed = 0) { 5937a1579acSMatthias Springer // Collect ops so we can build our own reverse traversal. 5947a1579acSMatthias Springer SmallVector<Operation *> ops; 5957a1579acSMatthias Springer op->walk([&](Operation *op) { 5967a1579acSMatthias Springer // No tensors => no buffers. 
5977a1579acSMatthias Springer if (!hasTensorSemantics(op)) 5987a1579acSMatthias Springer return; 5997a1579acSMatthias Springer ops.push_back(op); 6007a1579acSMatthias Springer }); 6017a1579acSMatthias Springer 6027a1579acSMatthias Springer return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed); 6037a1579acSMatthias Springer } 6047a1579acSMatthias Springer 6057a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. 6067a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops, 6077a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo, 6087a1579acSMatthias Springer BufferizationState &state) { 6097a1579acSMatthias Springer for (Operation *op : ops) 6107a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 6117a1579acSMatthias Springer for (OpResult opResult : op->getOpResults()) 6127a1579acSMatthias Springer if (opResult.getType().isa<TensorType>()) 6137a1579acSMatthias Springer for (OpOperand *opOperand : 6147a1579acSMatthias Springer bufferizableOp.getAliasingOpOperand(opResult, state)) 6157a1579acSMatthias Springer if (state.isInPlace(*opOperand)) 6167a1579acSMatthias Springer if (bufferizableOp.bufferRelation(opResult, state) == 6177a1579acSMatthias Springer BufferRelation::Equivalent) 6187a1579acSMatthias Springer aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); 6197a1579acSMatthias Springer } 6207a1579acSMatthias Springer 6217a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained 6227a1579acSMatthias Springer /// in `op`. 6237a1579acSMatthias Springer static void equivalenceAnalysis(Operation *op, 6247a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo, 6257a1579acSMatthias Springer BufferizationState &state) { 6267a1579acSMatthias Springer // Traverse ops in PostOrder: Nested ops first, then enclosing ops. 
6277a1579acSMatthias Springer SmallVector<Operation *> ops; 6287a1579acSMatthias Springer op->walk<WalkOrder::PostOrder>([&](Operation *op) { 6297a1579acSMatthias Springer // No tensors => no buffers. 6307a1579acSMatthias Springer if (none_of(op->getResultTypes(), isaTensor)) 6317a1579acSMatthias Springer return; 6327a1579acSMatthias Springer ops.push_back(op); 6337a1579acSMatthias Springer }); 6347a1579acSMatthias Springer 6357a1579acSMatthias Springer equivalenceAnalysis(ops, aliasInfo, state); 6367a1579acSMatthias Springer } 6377a1579acSMatthias Springer 6387a1579acSMatthias Springer /// Assert that the current bufferization decisions are consistent. 6397a1579acSMatthias Springer static LogicalResult 6407a1579acSMatthias Springer checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, 6417a1579acSMatthias Springer BufferizationState &state, 6427a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo) { 6437a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions(); 6447a1579acSMatthias Springer Operation *inconsistentOp = nullptr; 6457a1579acSMatthias Springer WalkResult walkResult = op->walk([&](Operation *op) { 6467a1579acSMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(op)) 6477a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 6487a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) { 6497a1579acSMatthias Springer if (wouldCreateReadAfterWriteInterference( 6507a1579acSMatthias Springer opOperand, domInfo, state, aliasInfo, 6517a1579acSMatthias Springer /*checkConsistencyOnly=*/true)) { 6527a1579acSMatthias Springer // This error can happen if certain "mustBufferizeInPlace" interface 6537a1579acSMatthias Springer // methods are implemented incorrectly, such that the IR already has 6547a1579acSMatthias Springer // a RaW conflict before making any bufferization decisions. 
6557a1579acSMatthias Springer inconsistentOp = op; 6567a1579acSMatthias Springer return WalkResult::interrupt(); 6577a1579acSMatthias Springer } 6587a1579acSMatthias Springer } 6597a1579acSMatthias Springer return WalkResult::advance(); 6607a1579acSMatthias Springer }); 6617a1579acSMatthias Springer 6627a1579acSMatthias Springer if (walkResult.wasInterrupted()) 6637a1579acSMatthias Springer return inconsistentOp->emitError("input IR has RaW conflict"); 6647a1579acSMatthias Springer return success(); 6657a1579acSMatthias Springer } 6667a1579acSMatthias Springer 6677a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only. 6687a1579acSMatthias Springer static void 6697a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op, 6707a1579acSMatthias Springer const BufferizationAliasInfo &aliasInfo, 6717a1579acSMatthias Springer BufferizationState &state) { 6727a1579acSMatthias Springer op->walk([&](Operation *op) { 6737a1579acSMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 6747a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 6757a1579acSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) 6767a1579acSMatthias Springer setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand)); 6777a1579acSMatthias Springer }); 6787a1579acSMatthias Springer } 6797a1579acSMatthias Springer 6807a1579acSMatthias Springer /// Assert that IR is in destination-passing style. I.e., every value that is 6817a1579acSMatthias Springer /// returned or yielded from a block is: 6827a1579acSMatthias Springer /// * aliasing a bbArg of that block or a parent block, or 6837a1579acSMatthias Springer /// * aliasing an OpResult of a op in a parent block. 
6847a1579acSMatthias Springer /// 6857a1579acSMatthias Springer /// Example: 6867a1579acSMatthias Springer /// ``` 6877a1579acSMatthias Springer /// %0 = "some_op" : tensor<?xf32> 6887a1579acSMatthias Springer /// %1 = scf.if %c -> (tensor<?xf32>) { 6897a1579acSMatthias Springer /// scf.yield %0 : tensor<?xf32> 6907a1579acSMatthias Springer /// } else { 6917a1579acSMatthias Springer /// %t = linalg.init_tensor : tensor<?xf32> 6927a1579acSMatthias Springer /// scf.yield %t : tensor<?xf32> 6937a1579acSMatthias Springer /// } 6947a1579acSMatthias Springer /// ``` 6957a1579acSMatthias Springer /// In the above example, the first scf.yield op satifies destination-passing 6967a1579acSMatthias Springer /// style because the yielded value %0 is defined in the parent block. The 6977a1579acSMatthias Springer /// second scf.yield op does not satisfy destination-passing style because the 6987a1579acSMatthias Springer /// yielded value %t is defined in the same block as the scf.yield op. 6997a1579acSMatthias Springer // TODO: The current implementation checks for equivalent values instead of 7007a1579acSMatthias Springer // aliasing values, which is stricter than needed. We can currently not check 7017a1579acSMatthias Springer // for aliasing values because the analysis is a maybe-alias analysis and we 7027a1579acSMatthias Springer // need a must-alias analysis here. 
703cdb7675cSMatthias Springer static LogicalResult 704cdb7675cSMatthias Springer assertDestinationPassingStyle(Operation *op, BufferizationState &state, 7057a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo, 706cdb7675cSMatthias Springer SmallVector<Operation *> &newOps) { 7077a1579acSMatthias Springer LogicalResult status = success(); 7087a1579acSMatthias Springer DominanceInfo domInfo(op); 7097a1579acSMatthias Springer op->walk([&](Operation *returnOp) { 7107a1579acSMatthias Springer if (!isRegionReturnLike(returnOp)) 7117a1579acSMatthias Springer return WalkResult::advance(); 7127a1579acSMatthias Springer 7137a1579acSMatthias Springer for (OpOperand &returnValOperand : returnOp->getOpOperands()) { 7147a1579acSMatthias Springer Value returnVal = returnValOperand.get(); 7157a1579acSMatthias Springer // Skip non-tensor values. 7167a1579acSMatthias Springer if (!returnVal.getType().isa<TensorType>()) 7177a1579acSMatthias Springer continue; 7187a1579acSMatthias Springer 7197a1579acSMatthias Springer bool foundEquivValue = false; 7207a1579acSMatthias Springer aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { 7217a1579acSMatthias Springer if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) { 7227a1579acSMatthias Springer Operation *definingOp = bbArg.getOwner()->getParentOp(); 7237a1579acSMatthias Springer if (definingOp->isProperAncestor(returnOp)) 7247a1579acSMatthias Springer foundEquivValue = true; 7257a1579acSMatthias Springer return; 7267a1579acSMatthias Springer } 7277a1579acSMatthias Springer 7287a1579acSMatthias Springer Operation *definingOp = equivVal.getDefiningOp(); 7297a1579acSMatthias Springer if (definingOp->getBlock()->findAncestorOpInBlock( 7307a1579acSMatthias Springer *returnOp->getParentOp())) 7317a1579acSMatthias Springer // Skip ops that happen after `returnOp` and parent ops. 
7327a1579acSMatthias Springer if (happensBefore(definingOp, returnOp, domInfo)) 7337a1579acSMatthias Springer foundEquivValue = true; 7347a1579acSMatthias Springer }); 7357a1579acSMatthias Springer 7367a1579acSMatthias Springer if (!foundEquivValue) 7377a1579acSMatthias Springer status = 7387a1579acSMatthias Springer returnOp->emitError() 7397a1579acSMatthias Springer << "operand #" << returnValOperand.getOperandNumber() 7407a1579acSMatthias Springer << " of ReturnLike op does not satisfy destination passing style"; 7417a1579acSMatthias Springer } 7427a1579acSMatthias Springer 7437a1579acSMatthias Springer return WalkResult::advance(); 7447a1579acSMatthias Springer }); 7457a1579acSMatthias Springer 7467a1579acSMatthias Springer return status; 7477a1579acSMatthias Springer } 7487a1579acSMatthias Springer 7497a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op, 7507a1579acSMatthias Springer AnalysisBufferizationState &state) { 7517a1579acSMatthias Springer DominanceInfo domInfo(op); 7527a1579acSMatthias Springer BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); 7537a1579acSMatthias Springer const auto &options = 7547a1579acSMatthias Springer static_cast<const AnalysisBufferizationOptions &>(state.getOptions()); 7557a1579acSMatthias Springer 7567a1579acSMatthias Springer if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) 7577a1579acSMatthias Springer return failure(); 7587a1579acSMatthias Springer 7597a1579acSMatthias Springer // If the analysis fails, just return. 
7607a1579acSMatthias Springer if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, 7617a1579acSMatthias Springer options.analysisFuzzerSeed))) 7627a1579acSMatthias Springer return failure(); 7637a1579acSMatthias Springer equivalenceAnalysis(op, aliasInfo, state); 7647a1579acSMatthias Springer 765cdb7675cSMatthias Springer for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) { 7667a1579acSMatthias Springer SmallVector<Operation *> newOps; 767cdb7675cSMatthias Springer if (failed(fn(op, state, aliasInfo, newOps))) 7687a1579acSMatthias Springer return failure(); 769cdb7675cSMatthias Springer // Analyze ops that were created by the PostAnalysisStepFn. 7707a1579acSMatthias Springer if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) 7717a1579acSMatthias Springer return failure(); 7727a1579acSMatthias Springer equivalenceAnalysis(newOps, aliasInfo, state); 7737a1579acSMatthias Springer } 7747a1579acSMatthias Springer 7757a1579acSMatthias Springer if (!options.allowReturnMemref) { 7767a1579acSMatthias Springer SmallVector<Operation *> newOps; 777cdb7675cSMatthias Springer if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps))) 7787a1579acSMatthias Springer return failure(); 7797a1579acSMatthias Springer } 7807a1579acSMatthias Springer 781*4ec00fb3SMatthias Springer // Analysis verification: After setting up alias/equivalence sets, each op 782*4ec00fb3SMatthias Springer // can check for expected invariants/limitations and fail the analysis if 783*4ec00fb3SMatthias Springer // necessary. 
784*4ec00fb3SMatthias Springer bool passedAnalysis = true; 785*4ec00fb3SMatthias Springer op->walk([&](Operation *op) { 786*4ec00fb3SMatthias Springer if (BufferizableOpInterface bufferizableOp = 787*4ec00fb3SMatthias Springer options.dynCastBufferizableOp(op)) 788*4ec00fb3SMatthias Springer if (failed(bufferizableOp.verifyAnalysis(state))) 789*4ec00fb3SMatthias Springer passedAnalysis = false; 790*4ec00fb3SMatthias Springer }); 791*4ec00fb3SMatthias Springer if (!passedAnalysis) 792*4ec00fb3SMatthias Springer return failure(); 793*4ec00fb3SMatthias Springer 7947a1579acSMatthias Springer // Annotate operations if we only want to report the analysis. 7957a1579acSMatthias Springer if (options.testAnalysisOnly) 7967a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(op, aliasInfo, state); 7977a1579acSMatthias Springer 7987a1579acSMatthias Springer return success(); 7997a1579acSMatthias Springer } 8007a1579acSMatthias Springer 8017a1579acSMatthias Springer LogicalResult bufferization::runOneShotBufferize( 8027a1579acSMatthias Springer Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) { 8037a1579acSMatthias Springer AnalysisBufferizationState state(op, *options); 8047a1579acSMatthias Springer if (failed(analyzeOp(op, state))) 8057a1579acSMatthias Springer return failure(); 8067a1579acSMatthias Springer if (options->testAnalysisOnly) 8077a1579acSMatthias Springer return success(); 8087a1579acSMatthias Springer return bufferizeOp(op, state); 8097a1579acSMatthias Springer } 810