17a1579acSMatthias Springer //===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer
97a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
107a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1136550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
127a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1379f11591SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
147a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
157a1579acSMatthias Springer #include "mlir/IR/BlockAndValueMapping.h"
167a1579acSMatthias Springer #include "mlir/IR/BuiltinOps.h"
177a1579acSMatthias Springer #include "mlir/IR/Operation.h"
187a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
197a1579acSMatthias Springer #include "mlir/IR/Value.h"
207a1579acSMatthias Springer #include "llvm/Support/Debug.h"
217a1579acSMatthias Springer
2287c770bbSMatthias Springer //===----------------------------------------------------------------------===//
2387c770bbSMatthias Springer // BufferizableOpInterface
2487c770bbSMatthias Springer //===----------------------------------------------------------------------===//
2587c770bbSMatthias Springer
267a1579acSMatthias Springer namespace mlir {
277a1579acSMatthias Springer namespace bufferization {
287a1579acSMatthias Springer
297a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
307a1579acSMatthias Springer
317a1579acSMatthias Springer } // namespace bufferization
327a1579acSMatthias Springer } // namespace mlir
337a1579acSMatthias Springer
347a1579acSMatthias Springer #define DEBUG_TYPE "bufferizable-op-interface"
357a1579acSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
367a1579acSMatthias Springer #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
377a1579acSMatthias Springer
387a1579acSMatthias Springer using namespace mlir;
397a1579acSMatthias Springer using namespace bufferization;
407a1579acSMatthias Springer
41c0b0b6a0SMatthias Springer /// Return the owner of the given value.
getOwnerOfValue(Value value)42c0b0b6a0SMatthias Springer static Operation *getOwnerOfValue(Value value) {
43c0b0b6a0SMatthias Springer if (auto opResult = value.dyn_cast<OpResult>())
44c0b0b6a0SMatthias Springer return opResult.getDefiningOp();
45c0b0b6a0SMatthias Springer return value.cast<BlockArgument>().getOwner()->getParentOp();
46c0b0b6a0SMatthias Springer }
47c0b0b6a0SMatthias Springer
allocationDoesNotEscape(OpResult opResult)48c66303c2SMatthias Springer bool bufferization::allocationDoesNotEscape(OpResult opResult) {
49c66303c2SMatthias Springer #ifndef NDEBUG
50c66303c2SMatthias Springer auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
51c66303c2SMatthias Springer assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
52c66303c2SMatthias Springer "expected op that bufferizes to an allocation");
53c66303c2SMatthias Springer #endif // NDEBUG
54c66303c2SMatthias Springer
55c66303c2SMatthias Springer Operation *op = opResult.getDefiningOp();
56c66303c2SMatthias Springer // If there is no 'escape' attribute, we cannot say for sure.
57c66303c2SMatthias Springer if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
58c66303c2SMatthias Springer return false;
59c66303c2SMatthias Springer auto attr =
60c66303c2SMatthias Springer op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
61c66303c2SMatthias Springer return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
62c66303c2SMatthias Springer }
63c66303c2SMatthias Springer
64b3ebe3beSMatthias Springer /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
65b3ebe3beSMatthias Springer /// shaped value is copied. Otherwise, a tensor with undefined contents is
66b3ebe3beSMatthias Springer /// allocated.
allocateTensorForShapedValue(OpBuilder & b,Location loc,Value shapedValue,bool escape,const BufferizationOptions & options,bool copy)6745b995cdSMatthias Springer FailureOr<Value> bufferization::allocateTensorForShapedValue(
6845b995cdSMatthias Springer OpBuilder &b, Location loc, Value shapedValue, bool escape,
6945b995cdSMatthias Springer const BufferizationOptions &options, bool copy) {
70b3ebe3beSMatthias Springer Value tensor;
71b3ebe3beSMatthias Springer if (shapedValue.getType().isa<RankedTensorType>()) {
72b3ebe3beSMatthias Springer tensor = shapedValue;
73b3ebe3beSMatthias Springer } else if (shapedValue.getType().isa<MemRefType>()) {
74b3ebe3beSMatthias Springer tensor = b.create<ToTensorOp>(loc, shapedValue);
7579f11591SMatthias Springer } else {
76b3ebe3beSMatthias Springer llvm_unreachable("expected RankedTensorType or MemRefType");
7779f11591SMatthias Springer }
78b3ebe3beSMatthias Springer RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
79b3ebe3beSMatthias Springer SmallVector<Value> dynamicSizes;
80b3ebe3beSMatthias Springer if (!copy) {
81b3ebe3beSMatthias Springer // Compute the dynamic part of the shape.
82b3ebe3beSMatthias Springer // First try to query the shape via ReifyRankedShapedTypeOpInterface.
83b3ebe3beSMatthias Springer bool reifiedShapes = false;
84b3ebe3beSMatthias Springer if (shapedValue.getType().isa<RankedTensorType>() &&
85b3ebe3beSMatthias Springer shapedValue.isa<OpResult>()) {
86b3ebe3beSMatthias Springer if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
87b3ebe3beSMatthias Springer shapedValue.getDefiningOp())) {
88b3ebe3beSMatthias Springer ReifiedRankedShapedTypeDims resultDims;
89b3ebe3beSMatthias Springer if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
90b3ebe3beSMatthias Springer reifiedShapes = true;
91b3ebe3beSMatthias Springer auto &shape =
92b3ebe3beSMatthias Springer resultDims[shapedValue.cast<OpResult>().getResultNumber()];
93b3ebe3beSMatthias Springer for (const auto &dim : enumerate(tensorType.getShape()))
94b3ebe3beSMatthias Springer if (ShapedType::isDynamic(dim.value()))
95b3ebe3beSMatthias Springer dynamicSizes.push_back(shape[dim.index()]);
96b3ebe3beSMatthias Springer }
97b3ebe3beSMatthias Springer }
98b3ebe3beSMatthias Springer }
99b3ebe3beSMatthias Springer
100b3ebe3beSMatthias Springer // If the shape could not be reified, create DimOps.
101b3ebe3beSMatthias Springer if (!reifiedShapes)
102b3ebe3beSMatthias Springer populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
103b3ebe3beSMatthias Springer }
104b3ebe3beSMatthias Springer
105c0b0b6a0SMatthias Springer // Create AllocTensorOp.
1063474d10eSMatthias Springer auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
1073474d10eSMatthias Springer copy ? tensor : Value());
1083474d10eSMatthias Springer allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
1093474d10eSMatthias Springer b.getBoolArrayAttr({escape}));
110c0b0b6a0SMatthias Springer
111c0b0b6a0SMatthias Springer // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
112c0b0b6a0SMatthias Springer if (copy)
113c0b0b6a0SMatthias Springer return allocTensorOp.getResult();
114c0b0b6a0SMatthias Springer FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
115c0b0b6a0SMatthias Springer if (failed(copyBufferType))
116c0b0b6a0SMatthias Springer return failure();
117c0b0b6a0SMatthias Springer allocTensorOp.setMemorySpaceAttr(
118c0b0b6a0SMatthias Springer b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
119c0b0b6a0SMatthias Springer copyBufferType->getMemorySpaceAsInt()));
12045b995cdSMatthias Springer return allocTensorOp.getResult();
12179f11591SMatthias Springer }
12279f11591SMatthias Springer
resolveTensorOpOperandConflicts(RewriterBase & rewriter,const AnalysisState & state)12387c770bbSMatthias Springer LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
12487c770bbSMatthias Springer RewriterBase &rewriter, const AnalysisState &state) {
12587b46776SMatthias Springer OpBuilder::InsertionGuard g(rewriter);
12687c770bbSMatthias Springer Operation *op = getOperation();
12787b46776SMatthias Springer SmallVector<OpOperand *> outOfPlaceOpOperands;
12879f11591SMatthias Springer DenseSet<OpOperand *> copiedOpOperands;
129a36c801dSMatthias Springer DenseSet<OpOperand *> escapingOpOperandCopies;
13087b46776SMatthias Springer SmallVector<OpResult> outOfPlaceOpResults;
13179f11591SMatthias Springer DenseSet<OpResult> copiedOpResults;
132a36c801dSMatthias Springer DenseSet<OpResult> escapingOpResultCopies;
13387b46776SMatthias Springer
13487b46776SMatthias Springer // Find all out-of-place OpOperands.
13587c770bbSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) {
13687c770bbSMatthias Springer Type operandType = opOperand.get().getType();
13787c770bbSMatthias Springer if (!operandType.isa<TensorType>())
13887c770bbSMatthias Springer continue;
13987c770bbSMatthias Springer if (state.isInPlace(opOperand))
14087c770bbSMatthias Springer continue;
14187c770bbSMatthias Springer if (operandType.isa<UnrankedTensorType>())
14287c770bbSMatthias Springer return op->emitError("copies of unranked tensors are not supported");
14387b46776SMatthias Springer
14487c770bbSMatthias Springer SmallVector<OpResult> aliasingOpResults =
14587c770bbSMatthias Springer state.getAliasingOpResult(opOperand);
146a36c801dSMatthias Springer // Is the result yielded from a block? Or are deallocations turned off
147a36c801dSMatthias Springer // entirely? In either case, mark the allocation as "escaping", so that it
148a36c801dSMatthias Springer // will not be deallocated.
149a36c801dSMatthias Springer bool escape = !state.getOptions().createDeallocs ||
150a36c801dSMatthias Springer llvm::any_of(aliasingOpResults, [&](Value v) {
151a36c801dSMatthias Springer return state.isTensorYielded(v);
152a36c801dSMatthias Springer });
153a36c801dSMatthias Springer
15487b46776SMatthias Springer if (aliasingOpResults.size() == 1 &&
15587b46776SMatthias Springer !state.bufferizesToMemoryWrite(opOperand) &&
15687b46776SMatthias Springer state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
15787b46776SMatthias Springer // The op itself does not write but may create exactly one alias. Instead
15887b46776SMatthias Springer // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
15987b46776SMatthias Springer // be smaller than the OpOperand (e.g., in the case of an extract_slice,
16087b46776SMatthias Springer // where the result is usually a smaller part of the source).
16187b46776SMatthias Springer outOfPlaceOpResults.push_back(aliasingOpResults.front());
16279f11591SMatthias Springer if (!state.canOmitTensorCopy(opOperand))
16379f11591SMatthias Springer copiedOpResults.insert(aliasingOpResults.front());
164a36c801dSMatthias Springer if (escape)
165a36c801dSMatthias Springer escapingOpResultCopies.insert(aliasingOpResults.front());
16687b46776SMatthias Springer } else {
16787b46776SMatthias Springer // In all other cases, make a copy of the OpOperand.
16887b46776SMatthias Springer outOfPlaceOpOperands.push_back(&opOperand);
16979f11591SMatthias Springer if (!state.canOmitTensorCopy(opOperand))
17079f11591SMatthias Springer copiedOpOperands.insert(&opOperand);
171a36c801dSMatthias Springer if (escape)
172a36c801dSMatthias Springer escapingOpOperandCopies.insert(&opOperand);
17387b46776SMatthias Springer }
17487b46776SMatthias Springer }
17587b46776SMatthias Springer
17687b46776SMatthias Springer // Insert copies of OpOperands.
17787b46776SMatthias Springer rewriter.setInsertionPoint(op);
17887b46776SMatthias Springer for (OpOperand *opOperand : outOfPlaceOpOperands) {
17945b995cdSMatthias Springer FailureOr<Value> copy = allocateTensorForShapedValue(
180a36c801dSMatthias Springer rewriter, op->getLoc(), opOperand->get(),
18145b995cdSMatthias Springer escapingOpOperandCopies.contains(opOperand), state.getOptions(),
18279f11591SMatthias Springer copiedOpOperands.contains(opOperand));
18345b995cdSMatthias Springer if (failed(copy))
18445b995cdSMatthias Springer return failure();
18545b995cdSMatthias Springer rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
18687c770bbSMatthias Springer }
18787b46776SMatthias Springer
18887b46776SMatthias Springer // Insert copies of OpResults.
18987b46776SMatthias Springer rewriter.setInsertionPointAfter(op);
19087b46776SMatthias Springer for (OpResult opResult : outOfPlaceOpResults) {
19145b995cdSMatthias Springer FailureOr<Value> copy = allocateTensorForShapedValue(
19245b995cdSMatthias Springer rewriter, op->getLoc(), opResult,
19345b995cdSMatthias Springer escapingOpResultCopies.contains(opResult), state.getOptions(),
19479f11591SMatthias Springer copiedOpResults.count(opResult));
19545b995cdSMatthias Springer if (failed(copy))
19645b995cdSMatthias Springer return failure();
19787b46776SMatthias Springer SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
19887b46776SMatthias Springer opResult.getUses(), [](OpOperand &use) { return &use; }));
19987b46776SMatthias Springer for (OpOperand *use : uses) {
20087b46776SMatthias Springer // Do not update the alloc_tensor op that we just created.
20145b995cdSMatthias Springer if (use->getOwner() != copy->getDefiningOp())
20245b995cdSMatthias Springer rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
20387b46776SMatthias Springer }
20487b46776SMatthias Springer }
20587b46776SMatthias Springer
20687c770bbSMatthias Springer return success();
20787c770bbSMatthias Springer }
20887c770bbSMatthias Springer
shouldDeallocateOpResult(OpResult opResult,const BufferizationOptions & options)209*664ffa46SMatthias Springer bool bufferization::shouldDeallocateOpResult(
210*664ffa46SMatthias Springer OpResult opResult, const BufferizationOptions &options) {
211*664ffa46SMatthias Springer Operation *op = opResult.getOwner();
212*664ffa46SMatthias Springer assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) &&
213*664ffa46SMatthias Springer "expected that op allocates");
214*664ffa46SMatthias Springer
215*664ffa46SMatthias Springer AnalysisState analysisState(options);
216*664ffa46SMatthias Springer if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
217*664ffa46SMatthias Springer // AllocTensorOp has one result.
218*664ffa46SMatthias Springer ArrayAttr escapeAttr =
219*664ffa46SMatthias Springer op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
220*664ffa46SMatthias Springer return !escapeAttr[0].cast<BoolAttr>().getValue();
221*664ffa46SMatthias Springer }
222*664ffa46SMatthias Springer
223*664ffa46SMatthias Springer // No "escape" annotation found.
224*664ffa46SMatthias Springer if (options.createDeallocs) {
225*664ffa46SMatthias Springer // Perform an ad-hoc analysis.
226*664ffa46SMatthias Springer return !analysisState.isTensorYielded(opResult);
227*664ffa46SMatthias Springer }
228*664ffa46SMatthias Springer
229*664ffa46SMatthias Springer return false;
230*664ffa46SMatthias Springer }
231*664ffa46SMatthias Springer
2327a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2331534177fSMatthias Springer // OpFilter
2341534177fSMatthias Springer //===----------------------------------------------------------------------===//
2351534177fSMatthias Springer
isOpAllowed(Operation * op) const2361534177fSMatthias Springer bool OpFilter::isOpAllowed(Operation *op) const {
2371534177fSMatthias Springer // All other ops: Allow/disallow according to filter.
2381534177fSMatthias Springer bool isAllowed = !hasAllowRule();
2391534177fSMatthias Springer for (const Entry &entry : entries) {
2401534177fSMatthias Springer bool filterResult = entry.fn(op);
2411534177fSMatthias Springer switch (entry.type) {
2421534177fSMatthias Springer case Entry::ALLOW:
2431534177fSMatthias Springer isAllowed |= filterResult;
2441534177fSMatthias Springer break;
2451534177fSMatthias Springer case Entry::DENY:
2461534177fSMatthias Springer if (filterResult)
2471534177fSMatthias Springer // DENY filter matches. This op is no allowed. (Even if other ALLOW
2481534177fSMatthias Springer // filters may match.)
2491534177fSMatthias Springer return false;
2501534177fSMatthias Springer };
2511534177fSMatthias Springer }
2521534177fSMatthias Springer return isAllowed;
2531534177fSMatthias Springer }
2541534177fSMatthias Springer
2551534177fSMatthias Springer //===----------------------------------------------------------------------===//
2567a1579acSMatthias Springer // BufferizationOptions
2577a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2587a1579acSMatthias Springer
259606f7c8fSMatthias Springer /// Default unknown type converter: Use a fully dynamic layout map.
260606f7c8fSMatthias Springer static BaseMemRefType
defaultUnknownTypeConverter(Value value,unsigned memorySpace,const BufferizationOptions & options)261606f7c8fSMatthias Springer defaultUnknownTypeConverter(Value value, unsigned memorySpace,
262606f7c8fSMatthias Springer const BufferizationOptions &options) {
263606f7c8fSMatthias Springer return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
264606f7c8fSMatthias Springer memorySpace);
265606f7c8fSMatthias Springer }
266606f7c8fSMatthias Springer
2677a1579acSMatthias Springer // Default constructor for BufferizationOptions.
BufferizationOptions()268606f7c8fSMatthias Springer BufferizationOptions::BufferizationOptions()
269606f7c8fSMatthias Springer : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
2707a1579acSMatthias Springer
isOpAllowed(Operation * op) const271d6dab38aSMatthias Springer bool BufferizationOptions::isOpAllowed(Operation *op) const {
272d6dab38aSMatthias Springer // Special case: If function boundary bufferization is deactivated, do not
273d6dab38aSMatthias Springer // allow ops that belong to the `func` dialect.
274d6dab38aSMatthias Springer bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
275d6dab38aSMatthias Springer if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
276d6dab38aSMatthias Springer return false;
277d6dab38aSMatthias Springer
2781534177fSMatthias Springer return opFilter.isOpAllowed(op);
279d6dab38aSMatthias Springer }
280d6dab38aSMatthias Springer
2817a1579acSMatthias Springer BufferizableOpInterface
dynCastBufferizableOp(Operation * op) const2827a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
2839785eb1bSMatthias Springer auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
2849785eb1bSMatthias Springer if (!bufferizableOp)
2857a1579acSMatthias Springer return nullptr;
2869785eb1bSMatthias Springer if (!isOpAllowed(op))
2879785eb1bSMatthias Springer return nullptr;
2889785eb1bSMatthias Springer return bufferizableOp;
2897a1579acSMatthias Springer }
2907a1579acSMatthias Springer
2917a1579acSMatthias Springer BufferizableOpInterface
dynCastBufferizableOp(Value value) const2927a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Value value) const {
2937a1579acSMatthias Springer if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
2947a1579acSMatthias Springer if (isOpAllowed(bufferizableOp.getOperation()))
2957a1579acSMatthias Springer return bufferizableOp;
2967a1579acSMatthias Springer return nullptr;
2977a1579acSMatthias Springer }
2987a1579acSMatthias Springer
addDialectStateInitializer(StringRef name,const DialectStateInitFn & fn)29951894cbbSMehdi Amini void BufferizationOptions::addDialectStateInitializer(
30051894cbbSMehdi Amini StringRef name, const DialectStateInitFn &fn) {
3016fc11d4dSMatthias Springer stateInitializers.push_back(
3029597b16aSMatthias Springer [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
3036fc11d4dSMatthias Springer }
3046fc11d4dSMatthias Springer
3057a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3067a1579acSMatthias Springer // Helper functions for BufferizableOpInterface
3077a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3087a1579acSMatthias Springer
setInsertionPointAfter(OpBuilder & b,Value value)3097a1579acSMatthias Springer static void setInsertionPointAfter(OpBuilder &b, Value value) {
3107a1579acSMatthias Springer if (auto bbArg = value.dyn_cast<BlockArgument>()) {
3117a1579acSMatthias Springer b.setInsertionPointToStart(bbArg.getOwner());
3127a1579acSMatthias Springer } else {
3137a1579acSMatthias Springer b.setInsertionPointAfter(value.getDefiningOp());
3147a1579acSMatthias Springer }
3157a1579acSMatthias Springer }
3167a1579acSMatthias Springer
3177a1579acSMatthias Springer /// Determine which OpOperand* will alias with `result` if the op is bufferized
3187a1579acSMatthias Springer /// in place. Return an empty vector if the op is not bufferizable.
3197a1579acSMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperand(OpResult result) const3209597b16aSMatthias Springer AnalysisState::getAliasingOpOperand(OpResult result) const {
3217a1579acSMatthias Springer if (Operation *op = result.getDefiningOp())
3222fe40c34SMatthias Springer if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
3237a1579acSMatthias Springer return bufferizableOp.getAliasingOpOperand(result, *this);
3247a1579acSMatthias Springer return {};
3257a1579acSMatthias Springer }
3267a1579acSMatthias Springer
3277a1579acSMatthias Springer /// Determine which OpResult will alias with `opOperand` if the op is bufferized
328585a8a32SMatthias Springer /// in place. Return an empty vector if the op is not bufferizable.
329585a8a32SMatthias Springer SmallVector<OpResult>
getAliasingOpResult(OpOperand & opOperand) const3309597b16aSMatthias Springer AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
3317a1579acSMatthias Springer if (auto bufferizableOp =
3322fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner()))
3337a1579acSMatthias Springer return bufferizableOp.getAliasingOpResult(opOperand, *this);
334585a8a32SMatthias Springer return {};
3357a1579acSMatthias Springer }
3367a1579acSMatthias Springer
3377a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
3387a1579acSMatthias Springer /// op is not bufferizable.
bufferizesToMemoryRead(OpOperand & opOperand) const3399597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
3407a1579acSMatthias Springer if (auto bufferizableOp =
3412fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner()))
3427a1579acSMatthias Springer return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
3437a1579acSMatthias Springer
3447a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it.
3457a1579acSMatthias Springer // Conservatively return true.
3467a1579acSMatthias Springer return true;
3477a1579acSMatthias Springer }
3487a1579acSMatthias Springer
3497a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory write. Return
3507a1579acSMatthias Springer /// `true` if the op is not bufferizable.
bufferizesToMemoryWrite(OpOperand & opOperand) const3519597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
3527a1579acSMatthias Springer if (auto bufferizableOp =
3532fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner()))
3547a1579acSMatthias Springer return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
3557a1579acSMatthias Springer
3567a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it.
3577a1579acSMatthias Springer // Conservatively return true.
3587a1579acSMatthias Springer return true;
3597a1579acSMatthias Springer }
3607a1579acSMatthias Springer
3617a1579acSMatthias Springer /// Return true if `opOperand` does neither read nor write but bufferizes to an
3627a1579acSMatthias Springer /// alias. Return false if the op is not bufferizable.
bufferizesToAliasOnly(OpOperand & opOperand) const3639597b16aSMatthias Springer bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
3647a1579acSMatthias Springer if (auto bufferizableOp =
3652fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner()))
3667a1579acSMatthias Springer return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
3677a1579acSMatthias Springer
3687a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it.
3697a1579acSMatthias Springer // Conservatively return false.
3707a1579acSMatthias Springer return false;
3717a1579acSMatthias Springer }
3727a1579acSMatthias Springer
3737a1579acSMatthias Springer /// Return true if the given value is read by an op that bufferizes to a memory
3747a1579acSMatthias Springer /// read. Also takes into account ops that create an alias but do not read by
3757a1579acSMatthias Springer /// themselves (e.g., ExtractSliceOp).
isValueRead(Value value) const3769597b16aSMatthias Springer bool AnalysisState::isValueRead(Value value) const {
3777a1579acSMatthias Springer assert(value.getType().isa<TensorType>() && "expected TensorType");
3787a1579acSMatthias Springer SmallVector<OpOperand *> workingSet;
3797a1579acSMatthias Springer for (OpOperand &use : value.getUses())
3807a1579acSMatthias Springer workingSet.push_back(&use);
3817a1579acSMatthias Springer
3827a1579acSMatthias Springer while (!workingSet.empty()) {
3837a1579acSMatthias Springer OpOperand *uMaybeReading = workingSet.pop_back_val();
3847a1579acSMatthias Springer // Skip over all ops that neither read nor write (but create an alias).
3857a1579acSMatthias Springer if (bufferizesToAliasOnly(*uMaybeReading))
386585a8a32SMatthias Springer for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
387585a8a32SMatthias Springer for (OpOperand &use : opResult.getUses())
3887a1579acSMatthias Springer workingSet.push_back(&use);
3897a1579acSMatthias Springer if (bufferizesToMemoryRead(*uMaybeReading))
3907a1579acSMatthias Springer return true;
3917a1579acSMatthias Springer }
3927a1579acSMatthias Springer
3937a1579acSMatthias Springer return false;
3947a1579acSMatthias Springer }
3957a1579acSMatthias Springer
3967a1579acSMatthias Springer // Starting from `value`, follow the use-def chain in reverse, always selecting
3977a1579acSMatthias Springer // the aliasing OpOperands. Find and return Values for which `condition`
3987a1579acSMatthias Springer // evaluates to true. OpOperands of such matching Values are not traversed any
3997a1579acSMatthias Springer // further.
findValueInReverseUseDefChain(Value value,llvm::function_ref<bool (Value)> condition) const4009597b16aSMatthias Springer llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
4017a1579acSMatthias Springer Value value, llvm::function_ref<bool(Value)> condition) const {
4027a1579acSMatthias Springer llvm::SetVector<Value> result, workingSet;
4037a1579acSMatthias Springer workingSet.insert(value);
4047a1579acSMatthias Springer
4057a1579acSMatthias Springer while (!workingSet.empty()) {
4067a1579acSMatthias Springer Value value = workingSet.pop_back_val();
4077a1579acSMatthias Springer if (condition(value) || value.isa<BlockArgument>()) {
4087a1579acSMatthias Springer result.insert(value);
4097a1579acSMatthias Springer continue;
4107a1579acSMatthias Springer }
4117a1579acSMatthias Springer
4127a1579acSMatthias Springer OpResult opResult = value.cast<OpResult>();
4137a1579acSMatthias Springer SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
4147a1579acSMatthias Springer if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
4157a1579acSMatthias Springer result.insert(value);
4167a1579acSMatthias Springer continue;
4177a1579acSMatthias Springer }
4187a1579acSMatthias Springer
4197a1579acSMatthias Springer for (OpOperand *o : opOperands)
4207a1579acSMatthias Springer workingSet.insert(o->get());
4217a1579acSMatthias Springer }
4227a1579acSMatthias Springer
4237a1579acSMatthias Springer return result;
4247a1579acSMatthias Springer }
4257a1579acSMatthias Springer
4267a1579acSMatthias Springer // Find the Values of the last preceding write of a given Value.
4277a1579acSMatthias Springer llvm::SetVector<Value>
findLastPrecedingWrite(Value value) const4289597b16aSMatthias Springer AnalysisState::findLastPrecedingWrite(Value value) const {
4297a1579acSMatthias Springer return findValueInReverseUseDefChain(value, [&](Value value) {
4307a1579acSMatthias Springer Operation *op = value.getDefiningOp();
4317a1579acSMatthias Springer if (!op)
4327a1579acSMatthias Springer return true;
4337a1579acSMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(op);
4347a1579acSMatthias Springer if (!bufferizableOp)
4357a1579acSMatthias Springer return true;
4367a1579acSMatthias Springer return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
4377a1579acSMatthias Springer });
4387a1579acSMatthias Springer }
4397a1579acSMatthias Springer
AnalysisState(const BufferizationOptions & options)4409597b16aSMatthias Springer AnalysisState::AnalysisState(const BufferizationOptions &options)
4416fc11d4dSMatthias Springer : options(options) {
4429597b16aSMatthias Springer for (const BufferizationOptions::AnalysisStateInitFn &fn :
4436fc11d4dSMatthias Springer options.stateInitializers)
4446fc11d4dSMatthias Springer fn(*this);
4456fc11d4dSMatthias Springer }
4467a1579acSMatthias Springer
canOmitTensorCopy(OpOperand & opOperand) const44779f11591SMatthias Springer bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
44879f11591SMatthias Springer // Do not copy if the tensor has undefined contents.
44979f11591SMatthias Springer if (hasUndefinedContents(&opOperand))
45079f11591SMatthias Springer return true;
45179f11591SMatthias Springer
45279f11591SMatthias Springer // Do not copy if the buffer of the tensor is entirely overwritten (with
45379f11591SMatthias Springer // values that do not depend on the old tensor).
45479f11591SMatthias Springer if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
45579f11591SMatthias Springer return true;
45679f11591SMatthias Springer
45779f11591SMatthias Springer // Do not copy if the tensor is never read.
45879f11591SMatthias Springer SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
45979f11591SMatthias Springer if (!bufferizesToMemoryRead(opOperand) &&
46079f11591SMatthias Springer llvm::none_of(aliasingOpResults,
46179f11591SMatthias Springer [&](OpResult opResult) { return isValueRead(opResult); }))
46279f11591SMatthias Springer return true;
46379f11591SMatthias Springer
46479f11591SMatthias Springer // Default: Cannot omit the copy.
46579f11591SMatthias Springer return false;
46679f11591SMatthias Springer }
46779f11591SMatthias Springer
isInPlace(OpOperand & opOperand) const468a3bca118SMatthias Springer bool AnalysisState::isInPlace(OpOperand &opOperand) const {
469b3ebe3beSMatthias Springer // ToMemrefOps are always in-place.
470b3ebe3beSMatthias Springer if (isa<ToMemrefOp>(opOperand.getOwner()))
471b3ebe3beSMatthias Springer return true;
472b3ebe3beSMatthias Springer
473a3bca118SMatthias Springer // In the absence of analysis information, OpOperands that bufferize to a
474a3bca118SMatthias Springer // memory write are out-of-place, i.e., an alloc and copy is inserted.
475a3bca118SMatthias Springer return !bufferizesToMemoryWrite(opOperand);
476a3bca118SMatthias Springer }
477a3bca118SMatthias Springer
areEquivalentBufferizedValues(Value v1,Value v2) const478a3bca118SMatthias Springer bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
479a3bca118SMatthias Springer // In the absence of analysis information, we do not know if the values are
480a3bca118SMatthias Springer // equivalent. The conservative answer is "false".
481a3bca118SMatthias Springer return false;
482a3bca118SMatthias Springer }
483a3bca118SMatthias Springer
areAliasingBufferizedValues(Value v1,Value v2) const484a3bca118SMatthias Springer bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
485a3bca118SMatthias Springer // In the absence of analysis information, we do not know if the values may be
486a3bca118SMatthias Springer // aliasing. The conservative answer is "true".
487f2ada383Slorenzo chelini return true;
488a3bca118SMatthias Springer }
489a3bca118SMatthias Springer
hasUndefinedContents(OpOperand * opOperand) const490a3bca118SMatthias Springer bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
491a3bca118SMatthias Springer // In the absence of analysis information, the conservative answer is "false".
492a3bca118SMatthias Springer return false;
493a3bca118SMatthias Springer }
494a3bca118SMatthias Springer
isTensorYielded(Value tensor) const495a3bca118SMatthias Springer bool AnalysisState::isTensorYielded(Value tensor) const {
496a3bca118SMatthias Springer // In the absence of analysis information, the conservative answer is "true".
497a36c801dSMatthias Springer if (!tensor.getDefiningOp<AllocTensorOp>())
498a3bca118SMatthias Springer return true;
499a36c801dSMatthias Springer
500a36c801dSMatthias Springer // For AllocTensorOp results, we can do better: They do not alias with any
501a36c801dSMatthias Springer // preceding value, so we can follow SSA use-def chains and do a simple
502a36c801dSMatthias Springer // analysis.
503a36c801dSMatthias Springer SmallVector<OpOperand *> worklist;
504a36c801dSMatthias Springer for (OpOperand &use : tensor.getUses())
505a36c801dSMatthias Springer worklist.push_back(&use);
506a36c801dSMatthias Springer
507a36c801dSMatthias Springer while (!worklist.empty()) {
508a36c801dSMatthias Springer OpOperand *operand = worklist.pop_back_val();
509a36c801dSMatthias Springer Operation *op = operand->getOwner();
510a36c801dSMatthias Springer
511a36c801dSMatthias Springer // If the op is not bufferizable, we can safely assume that the value is not
512a36c801dSMatthias Springer // yielded. (When bufferizing that op, it must handle such cases.)
513a36c801dSMatthias Springer if (!options.dynCastBufferizableOp(op))
514a36c801dSMatthias Springer continue;
515a36c801dSMatthias Springer
516a36c801dSMatthias Springer // We cannot analyze through ToMemrefOps, so we have to conservatively
517a36c801dSMatthias Springer // assume that the value is yielded.
518a36c801dSMatthias Springer if (isa<ToMemrefOp>(op))
519a36c801dSMatthias Springer return true;
520a36c801dSMatthias Springer
521a36c801dSMatthias Springer // Check if the op is returning/yielding.
522a36c801dSMatthias Springer if (isRegionReturnLike(op))
523a36c801dSMatthias Springer return true;
524a36c801dSMatthias Springer
525a36c801dSMatthias Springer // Add all aliasing OpResults to the worklist.
526a36c801dSMatthias Springer // Note: In the absence of detailed analysis information (e.g., there may be
527a36c801dSMatthias Springer // no function call analysis information), this `getAliasingOpResult` is
528a36c801dSMatthias Springer // conservative and may report additional OpResults as potentially aliasing.
529a36c801dSMatthias Springer for (OpResult opResult : getAliasingOpResult(*operand))
530a36c801dSMatthias Springer for (OpOperand &use : opResult.getUses())
531a36c801dSMatthias Springer worklist.push_back(&use);
532a36c801dSMatthias Springer }
533a36c801dSMatthias Springer
534a36c801dSMatthias Springer // No ReturnLike op found: The value is not yielded.
535a36c801dSMatthias Springer return false;
536a3bca118SMatthias Springer }
537a3bca118SMatthias Springer
5387a1579acSMatthias Springer // bufferization.to_memref is not allowed to change the rank.
ensureToMemrefOpIsValid(Value tensor,Type memrefType)5397a1579acSMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
5407a1579acSMatthias Springer #ifndef NDEBUG
5417a1579acSMatthias Springer auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
5427a1579acSMatthias Springer assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
5437a1579acSMatthias Springer rankedTensorType.getRank()) &&
5447a1579acSMatthias Springer "to_memref would be invalid: mismatching ranks");
5457a1579acSMatthias Springer #endif
5467a1579acSMatthias Springer }
5477a1579acSMatthias Springer
getBuffer(RewriterBase & rewriter,Value value,const BufferizationOptions & options)5485d50f51cSMatthias Springer FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
549b55d55ecSMatthias Springer const BufferizationOptions &options) {
550ba9d886dSMatthias Springer #ifndef NDEBUG
551b3ebe3beSMatthias Springer auto tensorType = value.getType().dyn_cast<TensorType>();
55226852423SMatthias Springer assert(tensorType && "unexpected non-tensor type");
553ba9d886dSMatthias Springer #endif // NDEBUG
5547a1579acSMatthias Springer
5557a1579acSMatthias Springer // Replace "%t = to_tensor %m" with %m.
556b3ebe3beSMatthias Springer if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
55799260e95SMatthias Springer return toTensorOp.getMemref();
5587a1579acSMatthias Springer
5597a1579acSMatthias Springer // Insert to_memref op.
5607a1579acSMatthias Springer OpBuilder::InsertionGuard g(rewriter);
561b3ebe3beSMatthias Springer setInsertionPointAfter(rewriter, value);
5625d50f51cSMatthias Springer FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
5635d50f51cSMatthias Springer if (failed(memrefType))
5645d50f51cSMatthias Springer return failure();
5655d50f51cSMatthias Springer ensureToMemrefOpIsValid(value, *memrefType);
5665d50f51cSMatthias Springer return rewriter
5675d50f51cSMatthias Springer .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
5685d50f51cSMatthias Springer .getResult();
5697a1579acSMatthias Springer }
5707a1579acSMatthias Springer
571996834e6SMatthias Springer /// Return the buffer type for a given Value (tensor) after bufferization.
5725d50f51cSMatthias Springer FailureOr<BaseMemRefType>
getBufferType(Value value,const BufferizationOptions & options)573b55d55ecSMatthias Springer bufferization::getBufferType(Value value, const BufferizationOptions &options) {
574606f7c8fSMatthias Springer assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
575c0b0b6a0SMatthias Springer Operation *op = getOwnerOfValue(value);
576d7a9bf91SMatthias Springer
577c0b0b6a0SMatthias Springer // ToTensorOp: Take buffer type directly from the op.
578996834e6SMatthias Springer if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
57999260e95SMatthias Springer return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
580d7a9bf91SMatthias Springer
581c0b0b6a0SMatthias Springer // If value is a bbArg of a bufferizable op: query op interface.
582ba9d886dSMatthias Springer if (auto bbArg = value.dyn_cast<BlockArgument>())
583ba9d886dSMatthias Springer if (auto bufferizableOp =
584ba9d886dSMatthias Springer options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
585ba9d886dSMatthias Springer return bufferizableOp.getBufferType(bbArg, options);
586ba9d886dSMatthias Springer
587c0b0b6a0SMatthias Springer // Check value is a new buffer allocation with a memory space attribute. In
588c0b0b6a0SMatthias Springer // that case we can at least infer the memory space.
589c0b0b6a0SMatthias Springer Optional<unsigned> memorySpace = None;
590c0b0b6a0SMatthias Springer if (auto opResult = value.dyn_cast<OpResult>()) {
591c0b0b6a0SMatthias Springer if (auto bufferizableOp =
592c0b0b6a0SMatthias Springer options.dynCastBufferizableOp(opResult.getDefiningOp())) {
593c0b0b6a0SMatthias Springer if (bufferizableOp.bufferizesToAllocation(opResult)) {
594c0b0b6a0SMatthias Springer FailureOr<unsigned> queriedMemorySpace =
595c0b0b6a0SMatthias Springer bufferizableOp.getMemorySpace(opResult);
596c0b0b6a0SMatthias Springer if (!failed(queriedMemorySpace))
597c0b0b6a0SMatthias Springer memorySpace = *queriedMemorySpace;
598c0b0b6a0SMatthias Springer }
599c0b0b6a0SMatthias Springer }
600c0b0b6a0SMatthias Springer }
601c0b0b6a0SMatthias Springer
602c0b0b6a0SMatthias Springer // If we still do not know the memory space, use the default memory space (if
603c0b0b6a0SMatthias Springer // any).
604491d2701SKazu Hirata if (!memorySpace.has_value())
605c0b0b6a0SMatthias Springer memorySpace = options.defaultMemorySpace;
606c0b0b6a0SMatthias Springer
607c0b0b6a0SMatthias Springer // If we still do not know the memory space, report a failure.
608491d2701SKazu Hirata if (!memorySpace.has_value())
609c0b0b6a0SMatthias Springer return op->emitError("could not infer memory space");
610c0b0b6a0SMatthias Springer
611606f7c8fSMatthias Springer return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
612d7a9bf91SMatthias Springer }
613d7a9bf91SMatthias Springer
replaceOpWithBufferizedValues(RewriterBase & rewriter,Operation * op,ValueRange values)6147a1579acSMatthias Springer void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
6157a1579acSMatthias Springer Operation *op,
6167a1579acSMatthias Springer ValueRange values) {
6179106d35bSMatthias Springer assert(values.size() == op->getNumResults() &&
6189106d35bSMatthias Springer "expected one value per OpResult");
6197a1579acSMatthias Springer OpBuilder::InsertionGuard g(rewriter);
6207a1579acSMatthias Springer
6217a1579acSMatthias Springer // Replace all OpResults with the given values.
6229106d35bSMatthias Springer SmallVector<Value> replacements;
6237a1579acSMatthias Springer for (OpResult opResult : op->getOpResults()) {
6247a1579acSMatthias Springer Value replacement = values[opResult.getResultNumber()];
6257a1579acSMatthias Springer if (opResult.getType().isa<TensorType>()) {
6267a1579acSMatthias Springer // The OpResult is a tensor. Such values are replaced with memrefs during
6277a1579acSMatthias Springer // bufferization.
6287a1579acSMatthias Springer assert((replacement.getType().isa<MemRefType>() ||
6297a1579acSMatthias Springer replacement.getType().isa<UnrankedMemRefType>()) &&
6307a1579acSMatthias Springer "tensor op result should be replaced with a memref value");
6317a1579acSMatthias Springer // The existing uses of the OpResult still expect a tensor. Insert a
6327a1579acSMatthias Springer // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
6337a1579acSMatthias Springer // loose all of its users and eventually DCE away.
634c30d2893SMatthias Springer rewriter.setInsertionPointAfter(op);
6357a1579acSMatthias Springer replacement = rewriter.create<bufferization::ToTensorOp>(
6367a1579acSMatthias Springer replacement.getLoc(), replacement);
6377a1579acSMatthias Springer }
6389106d35bSMatthias Springer replacements.push_back(replacement);
6397a1579acSMatthias Springer }
6407a1579acSMatthias Springer
6419106d35bSMatthias Springer rewriter.replaceOp(op, replacements);
6427a1579acSMatthias Springer }
6437a1579acSMatthias Springer
6447a1579acSMatthias Springer //===----------------------------------------------------------------------===//
6457a1579acSMatthias Springer // Bufferization-specific scoped alloc/dealloc insertion support.
6467a1579acSMatthias Springer //===----------------------------------------------------------------------===//
6477a1579acSMatthias Springer
64805e0495fSMatthias Springer /// Create a memref allocation with the given type and dynamic extents.
createAlloc(OpBuilder & b,Location loc,MemRefType type,ValueRange dynShape) const649248e113eSMatthias Springer FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
650248e113eSMatthias Springer MemRefType type,
651248e113eSMatthias Springer ValueRange dynShape) const {
652248e113eSMatthias Springer if (allocationFn)
653248e113eSMatthias Springer return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
65405e0495fSMatthias Springer
65505e0495fSMatthias Springer // Default bufferallocation via AllocOp.
656b3ebe3beSMatthias Springer if (bufferAlignment != 0)
657b3ebe3beSMatthias Springer return b
658b3ebe3beSMatthias Springer .create<memref::AllocOp>(loc, type, dynShape,
659b3ebe3beSMatthias Springer b.getI64IntegerAttr(bufferAlignment))
660b3ebe3beSMatthias Springer .getResult();
661b3ebe3beSMatthias Springer return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
66205e0495fSMatthias Springer }
66305e0495fSMatthias Springer
66405e0495fSMatthias Springer /// Creates a memref deallocation. The given memref buffer must have been
66505e0495fSMatthias Springer /// allocated using `createAlloc`.
createDealloc(OpBuilder & b,Location loc,Value allocatedBuffer) const666248e113eSMatthias Springer LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
667248e113eSMatthias Springer Value allocatedBuffer) const {
668248e113eSMatthias Springer if (deallocationFn)
669248e113eSMatthias Springer return (*deallocationFn)(b, loc, allocatedBuffer);
67005e0495fSMatthias Springer
67105e0495fSMatthias Springer // Default buffer deallocation via DeallocOp.
67205e0495fSMatthias Springer b.create<memref::DeallocOp>(loc, allocatedBuffer);
67305e0495fSMatthias Springer return success();
67405e0495fSMatthias Springer }
67505e0495fSMatthias Springer
6767a1579acSMatthias Springer /// Create a memory copy between two memref buffers.
createMemCpy(OpBuilder & b,Location loc,Value from,Value to) const677248e113eSMatthias Springer LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
678248e113eSMatthias Springer Value from, Value to) const {
679248e113eSMatthias Springer if (memCpyFn)
680248e113eSMatthias Springer return (*memCpyFn)(b, loc, from, to);
6817a1579acSMatthias Springer
6827a1579acSMatthias Springer b.create<memref::CopyOp>(loc, from, to);
6837a1579acSMatthias Springer return success();
6847a1579acSMatthias Springer }
6857a1579acSMatthias Springer
6867a1579acSMatthias Springer //===----------------------------------------------------------------------===//
6877a1579acSMatthias Springer // Bufferization-specific BlockAndValueMapping support with debugging.
6887a1579acSMatthias Springer //===----------------------------------------------------------------------===//
6897a1579acSMatthias Springer
isFunctionArgument(Value value)6907a1579acSMatthias Springer bool bufferization::isFunctionArgument(Value value) {
6917a1579acSMatthias Springer auto bbArg = value.dyn_cast<BlockArgument>();
6927a1579acSMatthias Springer if (!bbArg)
6937a1579acSMatthias Springer return false;
69458ceae95SRiver Riddle return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
6957a1579acSMatthias Springer }
6967a1579acSMatthias Springer
getMemRefType(Value value,const BufferizationOptions & options,MemRefLayoutAttrInterface layout,unsigned memorySpace)697606f7c8fSMatthias Springer BaseMemRefType bufferization::getMemRefType(Value value,
69826852423SMatthias Springer const BufferizationOptions &options,
69926852423SMatthias Springer MemRefLayoutAttrInterface layout,
700b06614e2SMatthias Springer unsigned memorySpace) {
701606f7c8fSMatthias Springer auto tensorType = value.getType().cast<TensorType>();
702b06614e2SMatthias Springer auto memorySpaceAttr = IntegerAttr::get(
703b06614e2SMatthias Springer IntegerType::get(tensorType.getContext(), 64), memorySpace);
704b06614e2SMatthias Springer
70526852423SMatthias Springer // Case 1: Unranked memref type.
70626852423SMatthias Springer if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
70726852423SMatthias Springer assert(!layout && "UnrankedTensorType cannot have a layout map");
70826852423SMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
709b06614e2SMatthias Springer memorySpaceAttr);
7107a1579acSMatthias Springer }
7117a1579acSMatthias Springer
712f287da8aSMatthias Springer // Case 2: Ranked memref type with specified layout.
71326852423SMatthias Springer auto rankedTensorType = tensorType.cast<RankedTensorType>();
714f287da8aSMatthias Springer if (layout) {
71526852423SMatthias Springer return MemRefType::get(rankedTensorType.getShape(),
71626852423SMatthias Springer rankedTensorType.getElementType(), layout,
717b06614e2SMatthias Springer memorySpaceAttr);
71826852423SMatthias Springer }
71926852423SMatthias Springer
720606f7c8fSMatthias Springer return options.unknownTypeConverterFn(value, memorySpace, options);
721f287da8aSMatthias Springer }
722f287da8aSMatthias Springer
723f287da8aSMatthias Springer BaseMemRefType
getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,unsigned memorySpace)724f287da8aSMatthias Springer bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
725b06614e2SMatthias Springer unsigned memorySpace) {
726f287da8aSMatthias Springer // Case 1: Unranked memref type.
727f287da8aSMatthias Springer if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
728f287da8aSMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
729f287da8aSMatthias Springer memorySpace);
730f287da8aSMatthias Springer }
731f287da8aSMatthias Springer
732f287da8aSMatthias Springer // Case 2: Ranked memref type.
733b06614e2SMatthias Springer auto memorySpaceAttr = IntegerAttr::get(
734b06614e2SMatthias Springer IntegerType::get(tensorType.getContext(), 64), memorySpace);
735f287da8aSMatthias Springer auto rankedTensorType = tensorType.cast<RankedTensorType>();
7367a1579acSMatthias Springer int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
73726852423SMatthias Springer SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
7387a1579acSMatthias Springer ShapedType::kDynamicStrideOrOffset);
7397a1579acSMatthias Springer AffineMap stridedLayout = makeStridedLinearLayoutMap(
74026852423SMatthias Springer dynamicStrides, dynamicOffset, rankedTensorType.getContext());
74126852423SMatthias Springer return MemRefType::get(rankedTensorType.getShape(),
74226852423SMatthias Springer rankedTensorType.getElementType(), stridedLayout,
743b06614e2SMatthias Springer memorySpaceAttr);
7447a1579acSMatthias Springer }
745f287da8aSMatthias Springer
746f287da8aSMatthias Springer /// Return a MemRef type with a static identity layout (i.e., no layout map). If
747f287da8aSMatthias Springer /// the given tensor type is unranked, return an unranked MemRef type.
748f287da8aSMatthias Springer BaseMemRefType
getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,unsigned memorySpace)749f287da8aSMatthias Springer bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
750b06614e2SMatthias Springer unsigned memorySpace) {
751f287da8aSMatthias Springer // Case 1: Unranked memref type.
752f287da8aSMatthias Springer if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
753f287da8aSMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
754f287da8aSMatthias Springer memorySpace);
755f287da8aSMatthias Springer }
756f287da8aSMatthias Springer
757f287da8aSMatthias Springer // Case 2: Ranked memref type.
758f287da8aSMatthias Springer auto rankedTensorType = tensorType.cast<RankedTensorType>();
759b06614e2SMatthias Springer auto memorySpaceAttr = IntegerAttr::get(
760b06614e2SMatthias Springer IntegerType::get(tensorType.getContext(), 64), memorySpace);
761f287da8aSMatthias Springer MemRefLayoutAttrInterface layout = {};
762f287da8aSMatthias Springer return MemRefType::get(rankedTensorType.getShape(),
763f287da8aSMatthias Springer rankedTensorType.getElementType(), layout,
764b06614e2SMatthias Springer memorySpaceAttr);
765f287da8aSMatthias Springer }
766