//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer
9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h"
16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
17e07a7fd5SMatthias Springer
18e07a7fd5SMatthias Springer namespace mlir {
19e07a7fd5SMatthias Springer namespace bufferization {
20e07a7fd5SMatthias Springer namespace func_ext {
21e07a7fd5SMatthias Springer
startFunctionAnalysis(FuncOp funcOp)22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
23e07a7fd5SMatthias Springer analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
24e07a7fd5SMatthias Springer auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
25e07a7fd5SMatthias Springer auto createdAliasingOperands =
26e07a7fd5SMatthias Springer aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
27e07a7fd5SMatthias Springer auto createdAliasingResults =
28e07a7fd5SMatthias Springer aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
29e07a7fd5SMatthias Springer auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
30e07a7fd5SMatthias Springer auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
31e07a7fd5SMatthias Springer (void)createdEquiv;
32e07a7fd5SMatthias Springer (void)createdAliasingOperands;
33e07a7fd5SMatthias Springer (void)createdAliasingResults;
34e07a7fd5SMatthias Springer (void)createdRead;
35e07a7fd5SMatthias Springer (void)createdWritten;
36e07a7fd5SMatthias Springer #ifndef NDEBUG
37e07a7fd5SMatthias Springer assert(createdEquiv.second && "equivalence info exists already");
38e07a7fd5SMatthias Springer assert(createdAliasingOperands.second && "aliasing info exists already");
39e07a7fd5SMatthias Springer assert(createdAliasingResults.second && "aliasing info exists already");
40e07a7fd5SMatthias Springer assert(createdRead.second && "bbarg access info exists already");
41e07a7fd5SMatthias Springer assert(createdWritten.second && "bbarg access info exists already");
42e07a7fd5SMatthias Springer #endif // NDEBUG
43e07a7fd5SMatthias Springer }
44e07a7fd5SMatthias Springer
45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`.
46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp.
getAssumedUniqueReturnOp(FuncOp funcOp)47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
48e07a7fd5SMatthias Springer func::ReturnOp returnOp;
49e07a7fd5SMatthias Springer for (Block &b : funcOp.getBody()) {
50e07a7fd5SMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
51e07a7fd5SMatthias Springer if (returnOp)
52e07a7fd5SMatthias Springer return nullptr;
53e07a7fd5SMatthias Springer returnOp = candidateOp;
54e07a7fd5SMatthias Springer }
55e07a7fd5SMatthias Springer }
56e07a7fd5SMatthias Springer return returnOp;
57e07a7fd5SMatthias Springer }
58e07a7fd5SMatthias Springer
59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the
60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be
61f287da8aSMatthias Springer /// specified by the user. If no layout map is specified, the default layout map
62f287da8aSMatthias Springer /// (as per `options.functionBoundaryTypeConversion`) is used.
63e07a7fd5SMatthias Springer static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp,int64_t index,const BufferizationOptions & options)64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
65e07a7fd5SMatthias Springer const BufferizationOptions &options) {
66e07a7fd5SMatthias Springer auto tensorType =
67e07a7fd5SMatthias Springer funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
68e07a7fd5SMatthias Springer assert(tensorType && "expected TensorType");
69f287da8aSMatthias Springer
70f287da8aSMatthias Springer BaseMemRefType memrefType;
71f287da8aSMatthias Springer if (options.functionBoundaryTypeConversion ==
72f287da8aSMatthias Springer BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
73f287da8aSMatthias Springer memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
74f287da8aSMatthias Springer } else {
75f287da8aSMatthias Springer // Note: Layout maps on function parameters cannot be inferred. The best we
76f287da8aSMatthias Springer // can do at the moment is "fully dynamic".
77f287da8aSMatthias Springer memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
78f287da8aSMatthias Springer }
79e07a7fd5SMatthias Springer
80e07a7fd5SMatthias Springer auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
81e07a7fd5SMatthias Springer index, BufferizationDialect::kBufferLayoutAttrName);
82e07a7fd5SMatthias Springer if (!layoutAttr)
83e07a7fd5SMatthias Springer return memrefType;
84e07a7fd5SMatthias Springer
85e07a7fd5SMatthias Springer auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
86e07a7fd5SMatthias Springer assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
87e07a7fd5SMatthias Springer return MemRefType::get(
88e07a7fd5SMatthias Springer rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
89e07a7fd5SMatthias Springer layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
90e07a7fd5SMatthias Springer }
91e07a7fd5SMatthias Springer
92e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`.
getCalledFunction(CallOpInterface callOp)93e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) {
94e07a7fd5SMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
95e07a7fd5SMatthias Springer if (!sym)
96e07a7fd5SMatthias Springer return nullptr;
97e07a7fd5SMatthias Springer return dyn_cast_or_null<FuncOp>(
98e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym));
99e07a7fd5SMatthias Springer }
100e07a7fd5SMatthias Springer
101e07a7fd5SMatthias Springer /// Get FuncAnalysisState.
102e07a7fd5SMatthias Springer static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState & state)103e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) {
104e07a7fd5SMatthias Springer Optional<const FuncAnalysisState *> maybeState =
105e07a7fd5SMatthias Springer state.getDialectState<FuncAnalysisState>(
106e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace());
1075413bf1bSKazu Hirata assert(maybeState && "FuncAnalysisState does not exist");
108e07a7fd5SMatthias Springer return **maybeState;
109e07a7fd5SMatthias Springer }
110e07a7fd5SMatthias Springer
111e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp.
getFuncOpAnalysisState(const AnalysisState & state,FuncOp funcOp)112e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
113e07a7fd5SMatthias Springer FuncOp funcOp) {
114cd80617aSMatthias Springer Optional<const FuncAnalysisState *> maybeState =
115cd80617aSMatthias Springer state.getDialectState<FuncAnalysisState>(
116cd80617aSMatthias Springer func::FuncDialect::getDialectNamespace());
117491d2701SKazu Hirata if (!maybeState.has_value())
118cd80617aSMatthias Springer return FuncOpAnalysisState::NotAnalyzed;
119*c27d8152SKazu Hirata const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
120cd80617aSMatthias Springer auto it = analyzedFuncOps.find(funcOp);
121cd80617aSMatthias Springer if (it == analyzedFuncOps.end())
122e07a7fd5SMatthias Springer return FuncOpAnalysisState::NotAnalyzed;
123e07a7fd5SMatthias Springer return it->second;
124e07a7fd5SMatthias Springer }
125e07a7fd5SMatthias Springer
126e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the
127e07a7fd5SMatthias Springer /// specified return value (if any).
getEquivalentFuncArgIdx(FuncOp funcOp,const FuncAnalysisState & state,int64_t returnValIdx)128e07a7fd5SMatthias Springer static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
129e07a7fd5SMatthias Springer const FuncAnalysisState &state,
130e07a7fd5SMatthias Springer int64_t returnValIdx) {
131e07a7fd5SMatthias Springer auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
132e07a7fd5SMatthias Springer if (funcOpIt == state.equivalentFuncArgs.end())
133e07a7fd5SMatthias Springer // No equivalence info stores for funcOp.
134e07a7fd5SMatthias Springer return None;
135e07a7fd5SMatthias Springer
136e07a7fd5SMatthias Springer auto retValIt = funcOpIt->getSecond().find(returnValIdx);
137e07a7fd5SMatthias Springer if (retValIt == funcOpIt->getSecond().end())
138e07a7fd5SMatthias Springer // Return value has no equivalent bbArg.
139e07a7fd5SMatthias Springer return None;
140e07a7fd5SMatthias Springer
141e07a7fd5SMatthias Springer return retValIt->getSecond();
142e07a7fd5SMatthias Springer }
143e07a7fd5SMatthias Springer
144e07a7fd5SMatthias Springer struct CallOpInterface
145e07a7fd5SMatthias Springer : public BufferizableOpInterface::ExternalModel<CallOpInterface,
146e07a7fd5SMatthias Springer func::CallOp> {
bufferizesToMemoryReadmlir::bufferization::func_ext::CallOpInterface147e07a7fd5SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
148e07a7fd5SMatthias Springer const AnalysisState &state) const {
149e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
150e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
151e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
152e07a7fd5SMatthias Springer
153e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
154e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Assume that OpOperand is read.
155e07a7fd5SMatthias Springer return true;
156e07a7fd5SMatthias Springer
157cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state);
158e07a7fd5SMatthias Springer return funcState.readBbArgs.lookup(funcOp).contains(
159e07a7fd5SMatthias Springer opOperand.getOperandNumber());
160e07a7fd5SMatthias Springer }
161e07a7fd5SMatthias Springer
bufferizesToMemoryWritemlir::bufferization::func_ext::CallOpInterface162e07a7fd5SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
163e07a7fd5SMatthias Springer const AnalysisState &state) const {
164e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
165e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
166e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
167e07a7fd5SMatthias Springer
168e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
169e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Assume that OpOperand is written.
170e07a7fd5SMatthias Springer return true;
171e07a7fd5SMatthias Springer
172cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state);
173e07a7fd5SMatthias Springer return funcState.writtenBbArgs.lookup(funcOp).contains(
174e07a7fd5SMatthias Springer opOperand.getOperandNumber());
175e07a7fd5SMatthias Springer }
176e07a7fd5SMatthias Springer
getAliasingOpResultmlir::bufferization::func_ext::CallOpInterface177e07a7fd5SMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
178e07a7fd5SMatthias Springer const AnalysisState &state) const {
179e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
180e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
181e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
182e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) !=
183e07a7fd5SMatthias Springer FuncOpAnalysisState::Analyzed) {
184e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Any OpResult may be aliasing.
185e07a7fd5SMatthias Springer SmallVector<OpResult> result;
186e07a7fd5SMatthias Springer for (OpResult opResult : op->getOpResults())
187e07a7fd5SMatthias Springer if (opResult.getType().isa<TensorType>())
188e07a7fd5SMatthias Springer result.push_back(opResult);
189e07a7fd5SMatthias Springer return result;
190e07a7fd5SMatthias Springer }
191e07a7fd5SMatthias Springer
192e07a7fd5SMatthias Springer // Get aliasing results from state.
193cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state);
194e07a7fd5SMatthias Springer auto aliasingReturnVals =
195e07a7fd5SMatthias Springer funcState.aliasingReturnVals.lookup(funcOp).lookup(
196e07a7fd5SMatthias Springer opOperand.getOperandNumber());
197e07a7fd5SMatthias Springer SmallVector<OpResult> result;
198e07a7fd5SMatthias Springer for (int64_t resultIdx : aliasingReturnVals)
199e07a7fd5SMatthias Springer result.push_back(callOp->getOpResult(resultIdx));
200e07a7fd5SMatthias Springer return result;
201e07a7fd5SMatthias Springer }
202e07a7fd5SMatthias Springer
203e07a7fd5SMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperandmlir::bufferization::func_ext::CallOpInterface204e07a7fd5SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult,
205e07a7fd5SMatthias Springer const AnalysisState &state) const {
206e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
207e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
208e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
209e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) !=
210e07a7fd5SMatthias Springer FuncOpAnalysisState::Analyzed) {
211e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Any OpOperand may be aliasing.
212e07a7fd5SMatthias Springer SmallVector<OpOperand *> result;
213e07a7fd5SMatthias Springer for (OpOperand &opOperand : op->getOpOperands())
214e07a7fd5SMatthias Springer if (opOperand.get().getType().isa<TensorType>())
215e07a7fd5SMatthias Springer result.push_back(&opOperand);
216e07a7fd5SMatthias Springer return result;
217e07a7fd5SMatthias Springer }
218e07a7fd5SMatthias Springer
219e07a7fd5SMatthias Springer // Get aliasing bbArgs from state.
220cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state);
221e07a7fd5SMatthias Springer auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
222e07a7fd5SMatthias Springer opResult.getResultNumber());
223e07a7fd5SMatthias Springer SmallVector<OpOperand *> result;
224e07a7fd5SMatthias Springer for (int64_t bbArgIdx : aliasingFuncArgs)
225e07a7fd5SMatthias Springer result.push_back(&callOp->getOpOperand(bbArgIdx));
226e07a7fd5SMatthias Springer return result;
227e07a7fd5SMatthias Springer }
228e07a7fd5SMatthias Springer
bufferRelationmlir::bufferization::func_ext::CallOpInterface229e07a7fd5SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
230e07a7fd5SMatthias Springer const AnalysisState &state) const {
23188539c5bSMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
23288539c5bSMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
23388539c5bSMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
23488539c5bSMatthias Springer if (getFuncOpAnalysisState(state, funcOp) !=
23588539c5bSMatthias Springer FuncOpAnalysisState::Analyzed) {
23688539c5bSMatthias Springer // Function not analyzed yet. The conservative answer is "None".
23788539c5bSMatthias Springer return BufferRelation::None;
23888539c5bSMatthias Springer }
23988539c5bSMatthias Springer
240cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state);
24188539c5bSMatthias Springer Optional<int64_t> maybeEquiv =
24288539c5bSMatthias Springer getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
243037f0995SKazu Hirata if (maybeEquiv) {
24488539c5bSMatthias Springer #ifndef NDEBUG
24588539c5bSMatthias Springer SmallVector<OpOperand *> aliasingOpOperands =
24688539c5bSMatthias Springer getAliasingOpOperand(op, opResult, state);
24788539c5bSMatthias Springer assert(aliasingOpOperands.size() == 1 &&
24888539c5bSMatthias Springer "expected exactly 1 aliasing OpOperand");
2496d5fc1e3SKazu Hirata assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
25088539c5bSMatthias Springer "inconsistent analysis state");
25188539c5bSMatthias Springer #endif
252e07a7fd5SMatthias Springer return BufferRelation::Equivalent;
253e07a7fd5SMatthias Springer }
25488539c5bSMatthias Springer return BufferRelation::None;
25588539c5bSMatthias Springer }
256e07a7fd5SMatthias Springer
257e07a7fd5SMatthias Springer /// All function arguments are writable. It is the responsibility of the
258e07a7fd5SMatthias Springer /// CallOp to insert buffer copies where necessary.
bufferizemlir::bufferization::func_ext::CallOpInterface259e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
260b55d55ecSMatthias Springer const BufferizationOptions &options) const {
261e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op);
262e07a7fd5SMatthias Springer unsigned numResults = callOp.getNumResults();
263e07a7fd5SMatthias Springer unsigned numOperands = callOp->getNumOperands();
264e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp);
265e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp");
266e07a7fd5SMatthias Springer FunctionType funcType = funcOp.getFunctionType();
267e07a7fd5SMatthias Springer
268e07a7fd5SMatthias Springer // Result types of the bufferized CallOp.
269e07a7fd5SMatthias Springer SmallVector<Type> resultTypes;
270e07a7fd5SMatthias Springer // Replacement values for the existing CallOp. These are usually the results
271e07a7fd5SMatthias Springer // of the bufferized CallOp, unless a tensor result folds onto an operand.
272e07a7fd5SMatthias Springer SmallVector<Value> replacementValues(numResults, Value());
273e07a7fd5SMatthias Springer // For non-tensor results: A mapping from return val indices of the old
274e07a7fd5SMatthias Springer // CallOp to return val indices of the bufferized CallOp.
275e07a7fd5SMatthias Springer SmallVector<Optional<unsigned>> retValMapping(numResults, None);
276e07a7fd5SMatthias Springer // Operands of the bufferized CallOp.
277e07a7fd5SMatthias Springer SmallVector<Value> newOperands(numOperands, Value());
278e07a7fd5SMatthias Springer
27988539c5bSMatthias Springer // 1. Compute the result types of the new CallOp.
280e07a7fd5SMatthias Springer for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
281e07a7fd5SMatthias Springer unsigned returnValIdx = it.index();
282e07a7fd5SMatthias Springer Type returnType = it.value();
283e07a7fd5SMatthias Springer if (!returnType.isa<TensorType>()) {
284e07a7fd5SMatthias Springer // Non-tensor values are returned.
285e07a7fd5SMatthias Springer retValMapping[returnValIdx] = resultTypes.size();
286e07a7fd5SMatthias Springer resultTypes.push_back(returnType);
287e07a7fd5SMatthias Springer continue;
288e07a7fd5SMatthias Springer }
289e07a7fd5SMatthias Springer
29088539c5bSMatthias Springer // Returning a memref.
291e07a7fd5SMatthias Springer retValMapping[returnValIdx] = resultTypes.size();
292e07a7fd5SMatthias Springer resultTypes.push_back(funcType.getResult(resultTypes.size()));
293e07a7fd5SMatthias Springer }
294e07a7fd5SMatthias Springer
295e07a7fd5SMatthias Springer // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
296e07a7fd5SMatthias Springer for (OpOperand &opOperand : callOp->getOpOperands()) {
297e07a7fd5SMatthias Springer unsigned idx = opOperand.getOperandNumber();
298e07a7fd5SMatthias Springer Value tensorOperand = opOperand.get();
299e07a7fd5SMatthias Springer
300e07a7fd5SMatthias Springer // Non-tensor operands are just copied.
301e07a7fd5SMatthias Springer if (!tensorOperand.getType().isa<TensorType>()) {
302e07a7fd5SMatthias Springer newOperands[idx] = tensorOperand;
303e07a7fd5SMatthias Springer continue;
304e07a7fd5SMatthias Springer }
305e07a7fd5SMatthias Springer
30688539c5bSMatthias Springer // Retrieve buffers for tensor operands.
307e07a7fd5SMatthias Springer Value buffer = newOperands[idx];
3085d50f51cSMatthias Springer if (!buffer) {
3095d50f51cSMatthias Springer FailureOr<Value> maybeBuffer =
3105d50f51cSMatthias Springer getBuffer(rewriter, opOperand.get(), options);
3115d50f51cSMatthias Springer if (failed(maybeBuffer))
3125d50f51cSMatthias Springer return failure();
3135d50f51cSMatthias Springer buffer = *maybeBuffer;
3145d50f51cSMatthias Springer }
315e07a7fd5SMatthias Springer
316e07a7fd5SMatthias Springer // Caller / callee type mismatch is handled with a CastOp.
317e07a7fd5SMatthias Springer auto memRefType = funcType.getInput(idx);
318e07a7fd5SMatthias Springer // Since we don't yet have a clear layout story, to_memref may
319e07a7fd5SMatthias Springer // conservatively turn tensors into more dynamic memref than necessary.
320e07a7fd5SMatthias Springer // If the memref type of the callee fails, introduce an extra memref.cast
321e07a7fd5SMatthias Springer // that will either canonicalize away or fail compilation until we can do
322e07a7fd5SMatthias Springer // something better.
323e07a7fd5SMatthias Springer if (buffer.getType() != memRefType) {
324e07a7fd5SMatthias Springer assert(
325e07a7fd5SMatthias Springer memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
326e07a7fd5SMatthias Springer "CallOp::bufferize: cast incompatible");
327e07a7fd5SMatthias Springer Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
328e07a7fd5SMatthias Springer memRefType, buffer);
329e07a7fd5SMatthias Springer buffer = castBuffer;
330e07a7fd5SMatthias Springer }
331e07a7fd5SMatthias Springer newOperands[idx] = buffer;
332e07a7fd5SMatthias Springer }
333e07a7fd5SMatthias Springer
334e07a7fd5SMatthias Springer // 3. Create the new CallOp.
335e07a7fd5SMatthias Springer Operation *newCallOp = rewriter.create<func::CallOp>(
336e07a7fd5SMatthias Springer callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
337e07a7fd5SMatthias Springer newCallOp->setAttrs(callOp->getAttrs());
33888539c5bSMatthias Springer // Get replacement values.
339e07a7fd5SMatthias Springer for (unsigned i = 0; i < replacementValues.size(); ++i) {
340e07a7fd5SMatthias Springer if (replacementValues[i])
341e07a7fd5SMatthias Springer continue;
342e07a7fd5SMatthias Springer replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
343e07a7fd5SMatthias Springer }
344e07a7fd5SMatthias Springer
345e07a7fd5SMatthias Springer // 4. Replace the old op with the new op.
346e07a7fd5SMatthias Springer replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
347e07a7fd5SMatthias Springer
348e07a7fd5SMatthias Springer return success();
349e07a7fd5SMatthias Springer }
350e07a7fd5SMatthias Springer };
351e07a7fd5SMatthias Springer
352e07a7fd5SMatthias Springer struct ReturnOpInterface
353e07a7fd5SMatthias Springer : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
354e07a7fd5SMatthias Springer func::ReturnOp> {
bufferizesToMemoryReadmlir::bufferization::func_ext::ReturnOpInterface355e07a7fd5SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
356e07a7fd5SMatthias Springer const AnalysisState &state) const {
357e07a7fd5SMatthias Springer return true;
358e07a7fd5SMatthias Springer }
359e07a7fd5SMatthias Springer
bufferizesToMemoryWritemlir::bufferization::func_ext::ReturnOpInterface360e07a7fd5SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
361e07a7fd5SMatthias Springer const AnalysisState &state) const {
362e07a7fd5SMatthias Springer return false;
363e07a7fd5SMatthias Springer }
364e07a7fd5SMatthias Springer
getAliasingOpResultmlir::bufferization::func_ext::ReturnOpInterface365e07a7fd5SMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
366e07a7fd5SMatthias Springer const AnalysisState &state) const {
367e07a7fd5SMatthias Springer return {};
368e07a7fd5SMatthias Springer }
369e07a7fd5SMatthias Springer
bufferizemlir::bufferization::func_ext::ReturnOpInterface370e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
371b55d55ecSMatthias Springer const BufferizationOptions &options) const {
372e07a7fd5SMatthias Springer #ifndef NDEBUG
373e07a7fd5SMatthias Springer auto returnOp = cast<func::ReturnOp>(op);
374e07a7fd5SMatthias Springer assert(isa<FuncOp>(returnOp->getParentOp()) &&
375e07a7fd5SMatthias Springer "only support FuncOp parent for ReturnOp");
376e07a7fd5SMatthias Springer #endif // NDEBUG
377e07a7fd5SMatthias Springer
378e07a7fd5SMatthias Springer // ReturnOps are bufferized as part of FuncOps.
3790b293bf0SMatthias Springer return success();
380e07a7fd5SMatthias Springer }
381e07a7fd5SMatthias Springer };
382e07a7fd5SMatthias Springer
383e07a7fd5SMatthias Springer struct FuncOpInterface
384e07a7fd5SMatthias Springer : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
385f287da8aSMatthias Springer /// Rewrite function bbArgs and return values into buffer form. This function
386f287da8aSMatthias Springer /// bufferizes the function signature and the ReturnOp. When the entire
387f287da8aSMatthias Springer /// function body has been bufferized, function return types can be switched
388f287da8aSMatthias Springer /// to more concise memref types as part of `foldMemRefCasts`.
389e07a7fd5SMatthias Springer ///
390e07a7fd5SMatthias Springer /// All function bbArgs are writable unless they are explicitly marked as
391e07a7fd5SMatthias Springer /// read-only. Callers must insert copies when needed.
bufferizemlir::bufferization::func_ext::FuncOpInterface392e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
393b55d55ecSMatthias Springer const BufferizationOptions &options) const {
394e07a7fd5SMatthias Springer auto funcOp = cast<FuncOp>(op);
395e07a7fd5SMatthias Springer FunctionType funcType = funcOp.getFunctionType();
396e07a7fd5SMatthias Springer
397e07a7fd5SMatthias Springer // Construct the bufferized function type.
398e07a7fd5SMatthias Springer SmallVector<Type> argTypes;
399e07a7fd5SMatthias Springer for (const auto &it : llvm::enumerate(funcType.getInputs())) {
400e07a7fd5SMatthias Springer Type argType = it.value();
401e07a7fd5SMatthias Springer if (auto tensorType = argType.dyn_cast<TensorType>()) {
402e07a7fd5SMatthias Springer argTypes.push_back(
403e07a7fd5SMatthias Springer getBufferizedFunctionArgType(funcOp, it.index(), options));
404e07a7fd5SMatthias Springer continue;
405e07a7fd5SMatthias Springer }
406e07a7fd5SMatthias Springer argTypes.push_back(argType);
407e07a7fd5SMatthias Springer }
408e07a7fd5SMatthias Springer
409e07a7fd5SMatthias Springer // Bodiless functions are assumed opaque and we cannot know the
410e07a7fd5SMatthias Springer // bufferization contract they want to enforce. As a consequence, only
411e07a7fd5SMatthias Springer // support functions that don't return any tensors atm.
412e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) {
413e07a7fd5SMatthias Springer SmallVector<Type> retTypes;
414e07a7fd5SMatthias Springer for (Type resultType : funcType.getResults()) {
415e07a7fd5SMatthias Springer if (resultType.isa<TensorType>())
416e07a7fd5SMatthias Springer return funcOp->emitError() << "cannot bufferize bodiless function "
417e07a7fd5SMatthias Springer << "that returns a tensor";
418e07a7fd5SMatthias Springer retTypes.push_back(resultType);
419e07a7fd5SMatthias Springer }
420e07a7fd5SMatthias Springer funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
421e07a7fd5SMatthias Springer return success();
422e07a7fd5SMatthias Springer }
423e07a7fd5SMatthias Springer
424e07a7fd5SMatthias Springer // TODO: Support functions with multiple returns.
425e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
426e07a7fd5SMatthias Springer assert(returnOp && "expected func with single return op");
427f287da8aSMatthias Springer Location loc = returnOp.getLoc();
428e07a7fd5SMatthias Springer
429e07a7fd5SMatthias Springer // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
430e07a7fd5SMatthias Springer Block &frontBlock = funcOp.getBody().front();
431e07a7fd5SMatthias Springer for (BlockArgument &bbArg : frontBlock.getArguments()) {
432e07a7fd5SMatthias Springer auto tensorType = bbArg.getType().dyn_cast<TensorType>();
433e07a7fd5SMatthias Springer // Non-tensor types stay the same.
434e07a7fd5SMatthias Springer if (!tensorType)
435e07a7fd5SMatthias Springer continue;
436e07a7fd5SMatthias Springer
437e07a7fd5SMatthias Springer // Collect all uses of the bbArg.
438e07a7fd5SMatthias Springer SmallVector<OpOperand *> bbArgUses;
439e07a7fd5SMatthias Springer for (OpOperand &use : bbArg.getUses())
440e07a7fd5SMatthias Springer bbArgUses.push_back(&use);
441e07a7fd5SMatthias Springer
442e07a7fd5SMatthias Springer // Change the bbArg type to memref.
443e07a7fd5SMatthias Springer Type memrefType =
444e07a7fd5SMatthias Springer getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
445e07a7fd5SMatthias Springer bbArg.setType(memrefType);
446e07a7fd5SMatthias Springer
447e07a7fd5SMatthias Springer // Replace all uses of the original tensor bbArg.
448e07a7fd5SMatthias Springer rewriter.setInsertionPointToStart(&frontBlock);
449e07a7fd5SMatthias Springer if (!bbArgUses.empty()) {
450e07a7fd5SMatthias Springer // Insert to_tensor because the remaining function body has not been
451e07a7fd5SMatthias Springer // bufferized yet.
452e07a7fd5SMatthias Springer Value toTensorOp =
453e07a7fd5SMatthias Springer rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
454e07a7fd5SMatthias Springer for (OpOperand *use : bbArgUses)
455e07a7fd5SMatthias Springer use->set(toTensorOp);
456e07a7fd5SMatthias Springer }
457e07a7fd5SMatthias Springer }
458e07a7fd5SMatthias Springer
459e07a7fd5SMatthias Springer // 2. For each result, keep track of which inplace argument it reuses.
460e07a7fd5SMatthias Springer SmallVector<Value> returnValues;
461e07a7fd5SMatthias Springer for (OpOperand &returnOperand : returnOp->getOpOperands()) {
462e07a7fd5SMatthias Springer Value returnVal = returnOperand.get();
463f287da8aSMatthias Springer auto tensorType = returnVal.getType().dyn_cast<TensorType>();
464f287da8aSMatthias Springer rewriter.setInsertionPoint(returnOp);
465e07a7fd5SMatthias Springer
466e07a7fd5SMatthias Springer // If not a tensor type just forward it.
467f287da8aSMatthias Springer if (!tensorType) {
468e07a7fd5SMatthias Springer returnValues.push_back(returnVal);
469e07a7fd5SMatthias Springer continue;
470e07a7fd5SMatthias Springer }
471e07a7fd5SMatthias Springer
472f287da8aSMatthias Springer BaseMemRefType resultType;
473f287da8aSMatthias Springer if (options.functionBoundaryTypeConversion ==
474f287da8aSMatthias Springer BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
475f287da8aSMatthias Springer resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
476f287da8aSMatthias Springer } else {
477f287da8aSMatthias Springer // Note: If `InferLayoutMap`, cast are later folded away.
478f287da8aSMatthias Springer resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
479f287da8aSMatthias Springer }
480f287da8aSMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
481f287da8aSMatthias Springer loc, resultType, returnVal);
482f287da8aSMatthias Springer returnValues.push_back(toMemrefOp);
483e07a7fd5SMatthias Springer }
484e07a7fd5SMatthias Springer
485e07a7fd5SMatthias Springer // 3. Rewrite the terminator without the in-place bufferizable values.
486e07a7fd5SMatthias Springer returnOp.operandsMutable().assign(returnValues);
487e07a7fd5SMatthias Springer
488e07a7fd5SMatthias Springer // 4. Rewrite the FuncOp type to buffer form.
489e07a7fd5SMatthias Springer funcOp.setType(FunctionType::get(op->getContext(), argTypes,
490e07a7fd5SMatthias Springer ValueRange(returnValues).getTypes()));
491e07a7fd5SMatthias Springer
492e07a7fd5SMatthias Springer return success();
493e07a7fd5SMatthias Springer }
494e07a7fd5SMatthias Springer
495e07a7fd5SMatthias Springer /// Return `true` if the given function argument is writable.
isWritablemlir::bufferization::func_ext::FuncOpInterface496e07a7fd5SMatthias Springer bool isWritable(Operation *op, Value value,
497e07a7fd5SMatthias Springer const AnalysisState &state) const {
498e07a7fd5SMatthias Springer auto funcOp = cast<FuncOp>(op);
499e07a7fd5SMatthias Springer BlockArgument bbArg = value.dyn_cast<BlockArgument>();
500e07a7fd5SMatthias Springer assert(bbArg && "expected BlockArgument");
501e07a7fd5SMatthias Springer
502e07a7fd5SMatthias Springer // "bufferization.writable" overrides other writability decisions. This is
503e07a7fd5SMatthias Springer // currently used for testing only.
504e07a7fd5SMatthias Springer if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
505e07a7fd5SMatthias Springer bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
506e07a7fd5SMatthias Springer return writable.getValue();
507e07a7fd5SMatthias Springer
508e07a7fd5SMatthias Springer // All function arguments are writable by default.
509e07a7fd5SMatthias Springer return true;
510e07a7fd5SMatthias Springer }
511e07a7fd5SMatthias Springer };
512e07a7fd5SMatthias Springer
513e07a7fd5SMatthias Springer } // namespace func_ext
514e07a7fd5SMatthias Springer } // namespace bufferization
515e07a7fd5SMatthias Springer } // namespace mlir
516e07a7fd5SMatthias Springer
517e07a7fd5SMatthias Springer void mlir::bufferization::func_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)518e07a7fd5SMatthias Springer registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
519e07a7fd5SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
520e07a7fd5SMatthias Springer func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
521e07a7fd5SMatthias Springer func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
522e07a7fd5SMatthias Springer func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
523e07a7fd5SMatthias Springer });
524e07a7fd5SMatthias Springer }
525