1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2e07a7fd5SMatthias Springer //
3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e07a7fd5SMatthias Springer //
7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer 
9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h"
16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
17e07a7fd5SMatthias Springer 
18e07a7fd5SMatthias Springer namespace mlir {
19e07a7fd5SMatthias Springer namespace bufferization {
20e07a7fd5SMatthias Springer namespace func_ext {
21e07a7fd5SMatthias Springer 
22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
23e07a7fd5SMatthias Springer   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
24e07a7fd5SMatthias Springer   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
25e07a7fd5SMatthias Springer   auto createdAliasingOperands =
26e07a7fd5SMatthias Springer       aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
27e07a7fd5SMatthias Springer   auto createdAliasingResults =
28e07a7fd5SMatthias Springer       aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
29e07a7fd5SMatthias Springer   auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
30e07a7fd5SMatthias Springer   auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
31e07a7fd5SMatthias Springer   (void)createdEquiv;
32e07a7fd5SMatthias Springer   (void)createdAliasingOperands;
33e07a7fd5SMatthias Springer   (void)createdAliasingResults;
34e07a7fd5SMatthias Springer   (void)createdRead;
35e07a7fd5SMatthias Springer   (void)createdWritten;
36e07a7fd5SMatthias Springer #ifndef NDEBUG
37e07a7fd5SMatthias Springer   assert(createdEquiv.second && "equivalence info exists already");
38e07a7fd5SMatthias Springer   assert(createdAliasingOperands.second && "aliasing info exists already");
39e07a7fd5SMatthias Springer   assert(createdAliasingResults.second && "aliasing info exists already");
40e07a7fd5SMatthias Springer   assert(createdRead.second && "bbarg access info exists already");
41e07a7fd5SMatthias Springer   assert(createdWritten.second && "bbarg access info exists already");
42e07a7fd5SMatthias Springer #endif // NDEBUG
43e07a7fd5SMatthias Springer }
44e07a7fd5SMatthias Springer 
45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`.
46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp.
47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
48e07a7fd5SMatthias Springer   func::ReturnOp returnOp;
49e07a7fd5SMatthias Springer   for (Block &b : funcOp.getBody()) {
50e07a7fd5SMatthias Springer     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
51e07a7fd5SMatthias Springer       if (returnOp)
52e07a7fd5SMatthias Springer         return nullptr;
53e07a7fd5SMatthias Springer       returnOp = candidateOp;
54e07a7fd5SMatthias Springer     }
55e07a7fd5SMatthias Springer   }
56e07a7fd5SMatthias Springer   return returnOp;
57e07a7fd5SMatthias Springer }
58e07a7fd5SMatthias Springer 
59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the
60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be
61e07a7fd5SMatthias Springer /// specified by the user. If no layout map is specified, a fully dynamic map is
62e07a7fd5SMatthias Springer /// used.
63e07a7fd5SMatthias Springer static BaseMemRefType
64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
65e07a7fd5SMatthias Springer                              const BufferizationOptions &options) {
66e07a7fd5SMatthias Springer   auto tensorType =
67e07a7fd5SMatthias Springer       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
68e07a7fd5SMatthias Springer   assert(tensorType && "expected TensorType");
69e07a7fd5SMatthias Springer   BaseMemRefType memrefType = getMemRefType(tensorType, options);
70e07a7fd5SMatthias Springer 
71e07a7fd5SMatthias Springer   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
72e07a7fd5SMatthias Springer       index, BufferizationDialect::kBufferLayoutAttrName);
73e07a7fd5SMatthias Springer   if (!layoutAttr)
74e07a7fd5SMatthias Springer     return memrefType;
75e07a7fd5SMatthias Springer 
76e07a7fd5SMatthias Springer   auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
77e07a7fd5SMatthias Springer   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
78e07a7fd5SMatthias Springer   return MemRefType::get(
79e07a7fd5SMatthias Springer       rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
80e07a7fd5SMatthias Springer       layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
81e07a7fd5SMatthias Springer }
82e07a7fd5SMatthias Springer 
83e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`.
84e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) {
85e07a7fd5SMatthias Springer   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
86e07a7fd5SMatthias Springer   if (!sym)
87e07a7fd5SMatthias Springer     return nullptr;
88e07a7fd5SMatthias Springer   return dyn_cast_or_null<FuncOp>(
89e07a7fd5SMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
90e07a7fd5SMatthias Springer }
91e07a7fd5SMatthias Springer 
92e07a7fd5SMatthias Springer /// Get FuncAnalysisState.
93e07a7fd5SMatthias Springer static const FuncAnalysisState &
94e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) {
95e07a7fd5SMatthias Springer   Optional<const FuncAnalysisState *> maybeState =
96e07a7fd5SMatthias Springer       state.getDialectState<FuncAnalysisState>(
97e07a7fd5SMatthias Springer           func::FuncDialect::getDialectNamespace());
98e07a7fd5SMatthias Springer   assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
99e07a7fd5SMatthias Springer   return **maybeState;
100e07a7fd5SMatthias Springer }
101e07a7fd5SMatthias Springer 
102e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp.
103e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
104e07a7fd5SMatthias Springer                                                   FuncOp funcOp) {
105e07a7fd5SMatthias Springer   const FuncAnalysisState &funcState = getFuncAnalysisState(state);
106e07a7fd5SMatthias Springer   auto it = funcState.analyzedFuncOps.find(funcOp);
107e07a7fd5SMatthias Springer   if (it == funcState.analyzedFuncOps.end())
108e07a7fd5SMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
109e07a7fd5SMatthias Springer   return it->second;
110e07a7fd5SMatthias Springer }
111e07a7fd5SMatthias Springer 
112e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the
113e07a7fd5SMatthias Springer /// specified return value (if any).
114e07a7fd5SMatthias Springer static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
115e07a7fd5SMatthias Springer                                                  const FuncAnalysisState &state,
116e07a7fd5SMatthias Springer                                                  int64_t returnValIdx) {
117e07a7fd5SMatthias Springer   auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
118e07a7fd5SMatthias Springer   if (funcOpIt == state.equivalentFuncArgs.end())
119e07a7fd5SMatthias Springer     // No equivalence info stores for funcOp.
120e07a7fd5SMatthias Springer     return None;
121e07a7fd5SMatthias Springer 
122e07a7fd5SMatthias Springer   auto retValIt = funcOpIt->getSecond().find(returnValIdx);
123e07a7fd5SMatthias Springer   if (retValIt == funcOpIt->getSecond().end())
124e07a7fd5SMatthias Springer     // Return value has no equivalent bbArg.
125e07a7fd5SMatthias Springer     return None;
126e07a7fd5SMatthias Springer 
127e07a7fd5SMatthias Springer   return retValIt->getSecond();
128e07a7fd5SMatthias Springer }
129e07a7fd5SMatthias Springer 
130e07a7fd5SMatthias Springer struct CallOpInterface
131e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CallOpInterface,
132e07a7fd5SMatthias Springer                                                     func::CallOp> {
133e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
134e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
135e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
136e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
137e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
138e07a7fd5SMatthias Springer 
139e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
140e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
141e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is read.
142e07a7fd5SMatthias Springer       return true;
143e07a7fd5SMatthias Springer 
144e07a7fd5SMatthias Springer     return funcState.readBbArgs.lookup(funcOp).contains(
145e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
146e07a7fd5SMatthias Springer   }
147e07a7fd5SMatthias Springer 
148e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
149e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
150e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
151e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
152e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
153e07a7fd5SMatthias Springer 
154e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
155e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
156e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is written.
157e07a7fd5SMatthias Springer       return true;
158e07a7fd5SMatthias Springer 
159e07a7fd5SMatthias Springer     return funcState.writtenBbArgs.lookup(funcOp).contains(
160e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
161e07a7fd5SMatthias Springer   }
162e07a7fd5SMatthias Springer 
163e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
164e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
165e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
166e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
167e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
168e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
169e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
170e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
171e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpResult may be aliasing.
172e07a7fd5SMatthias Springer       SmallVector<OpResult> result;
173e07a7fd5SMatthias Springer       for (OpResult opResult : op->getOpResults())
174e07a7fd5SMatthias Springer         if (opResult.getType().isa<TensorType>())
175e07a7fd5SMatthias Springer           result.push_back(opResult);
176e07a7fd5SMatthias Springer       return result;
177e07a7fd5SMatthias Springer     }
178e07a7fd5SMatthias Springer 
179e07a7fd5SMatthias Springer     // Get aliasing results from state.
180e07a7fd5SMatthias Springer     auto aliasingReturnVals =
181e07a7fd5SMatthias Springer         funcState.aliasingReturnVals.lookup(funcOp).lookup(
182e07a7fd5SMatthias Springer             opOperand.getOperandNumber());
183e07a7fd5SMatthias Springer     SmallVector<OpResult> result;
184e07a7fd5SMatthias Springer     for (int64_t resultIdx : aliasingReturnVals)
185e07a7fd5SMatthias Springer       result.push_back(callOp->getOpResult(resultIdx));
186e07a7fd5SMatthias Springer     return result;
187e07a7fd5SMatthias Springer   }
188e07a7fd5SMatthias Springer 
189e07a7fd5SMatthias Springer   SmallVector<OpOperand *>
190e07a7fd5SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
191e07a7fd5SMatthias Springer                        const AnalysisState &state) const {
192e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
193e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
194e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
195e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
196e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
197e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
198e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpOperand may be aliasing.
199e07a7fd5SMatthias Springer       SmallVector<OpOperand *> result;
200e07a7fd5SMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
201e07a7fd5SMatthias Springer         if (opOperand.get().getType().isa<TensorType>())
202e07a7fd5SMatthias Springer           result.push_back(&opOperand);
203e07a7fd5SMatthias Springer       return result;
204e07a7fd5SMatthias Springer     }
205e07a7fd5SMatthias Springer 
206e07a7fd5SMatthias Springer     // Get aliasing bbArgs from state.
207e07a7fd5SMatthias Springer     auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
208e07a7fd5SMatthias Springer         opResult.getResultNumber());
209e07a7fd5SMatthias Springer     SmallVector<OpOperand *> result;
210e07a7fd5SMatthias Springer     for (int64_t bbArgIdx : aliasingFuncArgs)
211e07a7fd5SMatthias Springer       result.push_back(&callOp->getOpOperand(bbArgIdx));
212e07a7fd5SMatthias Springer     return result;
213e07a7fd5SMatthias Springer   }
214e07a7fd5SMatthias Springer 
215e07a7fd5SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
216e07a7fd5SMatthias Springer                                 const AnalysisState &state) const {
217e07a7fd5SMatthias Springer     return BufferRelation::Equivalent;
218e07a7fd5SMatthias Springer   }
219e07a7fd5SMatthias Springer 
220e07a7fd5SMatthias Springer   /// All function arguments are writable. It is the responsibility of the
221e07a7fd5SMatthias Springer   /// CallOp to insert buffer copies where necessary.
222e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
223e07a7fd5SMatthias Springer                           BufferizationState &state) const {
224e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
225e07a7fd5SMatthias Springer     unsigned numResults = callOp.getNumResults();
226e07a7fd5SMatthias Springer     unsigned numOperands = callOp->getNumOperands();
227e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
228e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
229e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
230e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState =
231e07a7fd5SMatthias Springer         getFuncAnalysisState(state.getAnalysisState());
232e07a7fd5SMatthias Springer     const OneShotBufferizationOptions &options =
233e07a7fd5SMatthias Springer         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
234e07a7fd5SMatthias Springer 
235e07a7fd5SMatthias Springer     // Result types of the bufferized CallOp.
236e07a7fd5SMatthias Springer     SmallVector<Type> resultTypes;
237e07a7fd5SMatthias Springer     // Replacement values for the existing CallOp. These are usually the results
238e07a7fd5SMatthias Springer     // of the bufferized CallOp, unless a tensor result folds onto an operand.
239e07a7fd5SMatthias Springer     SmallVector<Value> replacementValues(numResults, Value());
240e07a7fd5SMatthias Springer     // For non-tensor results: A mapping from return val indices of the old
241e07a7fd5SMatthias Springer     // CallOp to return val indices of the bufferized CallOp.
242e07a7fd5SMatthias Springer     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
243e07a7fd5SMatthias Springer     // Operands of the bufferized CallOp.
244e07a7fd5SMatthias Springer     SmallVector<Value> newOperands(numOperands, Value());
245e07a7fd5SMatthias Springer 
246e07a7fd5SMatthias Springer     // Based on previously gathered equivalence information, we know if a
247e07a7fd5SMatthias Springer     // tensor result folds onto an operand. These are the only tensor value
248e07a7fd5SMatthias Springer     // results that are supported at the moment.
249e07a7fd5SMatthias Springer     //
250e07a7fd5SMatthias Springer     // For tensors return values that do not fold onto an operand, additional
251e07a7fd5SMatthias Springer     // work is needed (TODO) to either:
252e07a7fd5SMatthias Springer     // * hoist a result into an inplaceable operand or
253e07a7fd5SMatthias Springer     // * devise a better representation to truly return a buffer.
254e07a7fd5SMatthias Springer     //
255e07a7fd5SMatthias Springer     // Note: If a function has no body, no equivalence information is
256e07a7fd5SMatthias Springer     // available. Consequently, a tensor return value cannot be proven to fold
257e07a7fd5SMatthias Springer     // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
258e07a7fd5SMatthias Springer     // the moment.
259e07a7fd5SMatthias Springer 
260e07a7fd5SMatthias Springer     // 1. Compute the result types of the new CallOp. Tensor results that are
261e07a7fd5SMatthias Springer     // equivalent to a FuncOp bbArg are no longer returned.
262e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
263e07a7fd5SMatthias Springer       unsigned returnValIdx = it.index();
264e07a7fd5SMatthias Springer       Type returnType = it.value();
265e07a7fd5SMatthias Springer       if (!returnType.isa<TensorType>()) {
266e07a7fd5SMatthias Springer         // Non-tensor values are returned.
267e07a7fd5SMatthias Springer         retValMapping[returnValIdx] = resultTypes.size();
268e07a7fd5SMatthias Springer         resultTypes.push_back(returnType);
269e07a7fd5SMatthias Springer         continue;
270e07a7fd5SMatthias Springer       }
271e07a7fd5SMatthias Springer 
272e8f7d019SAlexander Belyaev       if (options.dropEquivalentFuncResults) {
273e07a7fd5SMatthias Springer         if (Optional<int64_t> bbArgIdx =
274e07a7fd5SMatthias Springer                 getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
275e07a7fd5SMatthias Springer           // Return operands that are equivalent to some bbArg, are not
276e07a7fd5SMatthias Springer           // returned.
277e07a7fd5SMatthias Springer           FailureOr<Value> bufferOrFailure =
278e07a7fd5SMatthias Springer               state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
279e07a7fd5SMatthias Springer           if (failed(bufferOrFailure))
280e07a7fd5SMatthias Springer             return failure();
281e07a7fd5SMatthias Springer           replacementValues[returnValIdx] = *bufferOrFailure;
282e07a7fd5SMatthias Springer           newOperands[*bbArgIdx] = *bufferOrFailure;
283e07a7fd5SMatthias Springer           continue;
284e07a7fd5SMatthias Springer         }
285e8f7d019SAlexander Belyaev       }
286e07a7fd5SMatthias Springer 
287e07a7fd5SMatthias Springer       if (!options.allowReturnAllocs)
288e07a7fd5SMatthias Springer         return callOp->emitError(
289e07a7fd5SMatthias Springer             "call to FuncOp that returns non-equivalent tensors not supported");
290e07a7fd5SMatthias Springer 
291e07a7fd5SMatthias Springer       // Returning a memref. This memref is not equivalent to any bbArg. It is
292e07a7fd5SMatthias Springer       // likely a newly allocated buffer. We may want to hoist such allocations
293e07a7fd5SMatthias Springer       // to the call site in the future.
294e07a7fd5SMatthias Springer       retValMapping[returnValIdx] = resultTypes.size();
295e07a7fd5SMatthias Springer       resultTypes.push_back(funcType.getResult(resultTypes.size()));
296e07a7fd5SMatthias Springer     }
297e07a7fd5SMatthias Springer 
298e07a7fd5SMatthias Springer     // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
299e07a7fd5SMatthias Springer     for (OpOperand &opOperand : callOp->getOpOperands()) {
300e07a7fd5SMatthias Springer       unsigned idx = opOperand.getOperandNumber();
301e07a7fd5SMatthias Springer       Value tensorOperand = opOperand.get();
302e07a7fd5SMatthias Springer 
303e07a7fd5SMatthias Springer       // Non-tensor operands are just copied.
304e07a7fd5SMatthias Springer       if (!tensorOperand.getType().isa<TensorType>()) {
305e07a7fd5SMatthias Springer         newOperands[idx] = tensorOperand;
306e07a7fd5SMatthias Springer         continue;
307e07a7fd5SMatthias Springer       }
308e07a7fd5SMatthias Springer 
309e07a7fd5SMatthias Springer       // Retrieve buffers for tensor operands. Tensor operand buffers, who's
310e07a7fd5SMatthias Springer       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
311e07a7fd5SMatthias Springer       // already stored in `newOperands` during Step 1.
312e07a7fd5SMatthias Springer       Value buffer = newOperands[idx];
313e07a7fd5SMatthias Springer       if (!buffer) {
314e07a7fd5SMatthias Springer         FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
315e07a7fd5SMatthias Springer         if (failed(bufferOrFailure))
316e07a7fd5SMatthias Springer           return failure();
317e07a7fd5SMatthias Springer         buffer = *bufferOrFailure;
318e07a7fd5SMatthias Springer       }
319e07a7fd5SMatthias Springer 
320e07a7fd5SMatthias Springer       // Caller / callee type mismatch is handled with a CastOp.
321e07a7fd5SMatthias Springer       auto memRefType = funcType.getInput(idx);
322e07a7fd5SMatthias Springer       // Since we don't yet have a clear layout story, to_memref may
323e07a7fd5SMatthias Springer       // conservatively turn tensors into more dynamic memref than necessary.
324e07a7fd5SMatthias Springer       // If the memref type of the callee fails, introduce an extra memref.cast
325e07a7fd5SMatthias Springer       // that will either canonicalize away or fail compilation until we can do
326e07a7fd5SMatthias Springer       // something better.
327e07a7fd5SMatthias Springer       if (buffer.getType() != memRefType) {
328e07a7fd5SMatthias Springer         assert(
329e07a7fd5SMatthias Springer             memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
330e07a7fd5SMatthias Springer             "CallOp::bufferize: cast incompatible");
331e07a7fd5SMatthias Springer         Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
332e07a7fd5SMatthias Springer                                                            memRefType, buffer);
333e07a7fd5SMatthias Springer         buffer = castBuffer;
334e07a7fd5SMatthias Springer       }
335e07a7fd5SMatthias Springer       newOperands[idx] = buffer;
336e07a7fd5SMatthias Springer     }
337e07a7fd5SMatthias Springer 
338e07a7fd5SMatthias Springer     // 3. Create the new CallOp.
339e07a7fd5SMatthias Springer     Operation *newCallOp = rewriter.create<func::CallOp>(
340e07a7fd5SMatthias Springer         callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
341e07a7fd5SMatthias Springer     newCallOp->setAttrs(callOp->getAttrs());
342e07a7fd5SMatthias Springer     // Get replacement values for non-tensor / non-equivalent results.
343e07a7fd5SMatthias Springer     for (unsigned i = 0; i < replacementValues.size(); ++i) {
344e07a7fd5SMatthias Springer       if (replacementValues[i])
345e07a7fd5SMatthias Springer         continue;
346e07a7fd5SMatthias Springer       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
347e07a7fd5SMatthias Springer     }
348e07a7fd5SMatthias Springer 
349e07a7fd5SMatthias Springer     // 4. Replace the old op with the new op.
350e07a7fd5SMatthias Springer     replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
351e07a7fd5SMatthias Springer 
352e07a7fd5SMatthias Springer     return success();
353e07a7fd5SMatthias Springer   }
354e07a7fd5SMatthias Springer };
355e07a7fd5SMatthias Springer 
356e07a7fd5SMatthias Springer struct ReturnOpInterface
357e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
358e07a7fd5SMatthias Springer                                                     func::ReturnOp> {
359e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
360e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
361e07a7fd5SMatthias Springer     return true;
362e07a7fd5SMatthias Springer   }
363e07a7fd5SMatthias Springer 
364e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
365e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
366e07a7fd5SMatthias Springer     return false;
367e07a7fd5SMatthias Springer   }
368e07a7fd5SMatthias Springer 
369e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
370e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
371e07a7fd5SMatthias Springer     return {};
372e07a7fd5SMatthias Springer   }
373e07a7fd5SMatthias Springer 
374e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
375e07a7fd5SMatthias Springer                           BufferizationState &state) const {
376e07a7fd5SMatthias Springer #ifndef NDEBUG
377e07a7fd5SMatthias Springer     auto returnOp = cast<func::ReturnOp>(op);
378e07a7fd5SMatthias Springer     assert(isa<FuncOp>(returnOp->getParentOp()) &&
379e07a7fd5SMatthias Springer            "only support FuncOp parent for ReturnOp");
380e07a7fd5SMatthias Springer #endif // NDEBUG
381e07a7fd5SMatthias Springer 
382e07a7fd5SMatthias Springer     // ReturnOps are bufferized as part of FuncOps.
383e07a7fd5SMatthias Springer     return failure();
384e07a7fd5SMatthias Springer   }
385e07a7fd5SMatthias Springer };
386e07a7fd5SMatthias Springer 
387e07a7fd5SMatthias Springer struct FuncOpInterface
388e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
389e07a7fd5SMatthias Springer   /// Rewrite function bbArgs and return values into buffer form (using the
390e07a7fd5SMatthias Springer   /// canonical memref layout for now). This function bufferizes the function
391e07a7fd5SMatthias Springer   /// signature and the ReturnOp. When the entire function body has been
392e07a7fd5SMatthias Springer   /// bufferized, function return types can be switched to more concise memref
393e07a7fd5SMatthias Springer   /// types as part of `foldMemRefCasts`.
394e07a7fd5SMatthias Springer   ///
395e07a7fd5SMatthias Springer   /// When a tensor function argument is known to be equivalent to a tensor
396e07a7fd5SMatthias Springer   /// result, it is dropped from the return values.
397e07a7fd5SMatthias Springer   ///
398e07a7fd5SMatthias Springer   /// All function bbArgs are writable unless they are explicitly marked as
399e07a7fd5SMatthias Springer   /// read-only. Callers must insert copies when needed.
400e07a7fd5SMatthias Springer   ///
401e07a7fd5SMatthias Springer   /// Note: Returning a memref is possible, but corresponding CallOp
402e07a7fd5SMatthias Springer   /// bufferizations fail unless `allowReturnAllocs`.
403e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
404e07a7fd5SMatthias Springer                           BufferizationState &state) const {
405e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
406e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
407e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState =
408e07a7fd5SMatthias Springer         getFuncAnalysisState(state.getAnalysisState());
409e8f7d019SAlexander Belyaev     const OneShotBufferizationOptions &options =
410e8f7d019SAlexander Belyaev         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
411e07a7fd5SMatthias Springer 
412e07a7fd5SMatthias Springer     // Construct the bufferized function type.
413e07a7fd5SMatthias Springer     SmallVector<Type> argTypes;
414e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
415e07a7fd5SMatthias Springer       Type argType = it.value();
416e07a7fd5SMatthias Springer       if (auto tensorType = argType.dyn_cast<TensorType>()) {
417e07a7fd5SMatthias Springer         argTypes.push_back(
418e07a7fd5SMatthias Springer             getBufferizedFunctionArgType(funcOp, it.index(), options));
419e07a7fd5SMatthias Springer         continue;
420e07a7fd5SMatthias Springer       }
421e07a7fd5SMatthias Springer       argTypes.push_back(argType);
422e07a7fd5SMatthias Springer     }
423e07a7fd5SMatthias Springer 
424e07a7fd5SMatthias Springer     // Bodiless functions are assumed opaque and we cannot know the
425e07a7fd5SMatthias Springer     // bufferization contract they want to enforce. As a consequence, only
426e07a7fd5SMatthias Springer     // support functions that don't return any tensors atm.
427e07a7fd5SMatthias Springer     if (funcOp.getBody().empty()) {
428e07a7fd5SMatthias Springer       SmallVector<Type> retTypes;
429e07a7fd5SMatthias Springer       for (Type resultType : funcType.getResults()) {
430e07a7fd5SMatthias Springer         if (resultType.isa<TensorType>())
431e07a7fd5SMatthias Springer           return funcOp->emitError() << "cannot bufferize bodiless function "
432e07a7fd5SMatthias Springer                                      << "that returns a tensor";
433e07a7fd5SMatthias Springer         retTypes.push_back(resultType);
434e07a7fd5SMatthias Springer       }
435e07a7fd5SMatthias Springer       funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
436e07a7fd5SMatthias Springer       return success();
437e07a7fd5SMatthias Springer     }
438e07a7fd5SMatthias Springer 
439e07a7fd5SMatthias Springer     // TODO: Support functions with multiple returns.
440e07a7fd5SMatthias Springer     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
441e07a7fd5SMatthias Springer     assert(returnOp && "expected func with single return op");
442e07a7fd5SMatthias Springer 
443e07a7fd5SMatthias Springer     // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
444e07a7fd5SMatthias Springer     Block &frontBlock = funcOp.getBody().front();
445e07a7fd5SMatthias Springer     for (BlockArgument &bbArg : frontBlock.getArguments()) {
446e07a7fd5SMatthias Springer       auto tensorType = bbArg.getType().dyn_cast<TensorType>();
447e07a7fd5SMatthias Springer       // Non-tensor types stay the same.
448e07a7fd5SMatthias Springer       if (!tensorType)
449e07a7fd5SMatthias Springer         continue;
450e07a7fd5SMatthias Springer 
451e07a7fd5SMatthias Springer       // Collect all uses of the bbArg.
452e07a7fd5SMatthias Springer       SmallVector<OpOperand *> bbArgUses;
453e07a7fd5SMatthias Springer       for (OpOperand &use : bbArg.getUses())
454e07a7fd5SMatthias Springer         bbArgUses.push_back(&use);
455e07a7fd5SMatthias Springer 
456e07a7fd5SMatthias Springer       // Change the bbArg type to memref.
457e07a7fd5SMatthias Springer       Type memrefType =
458e07a7fd5SMatthias Springer           getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
459e07a7fd5SMatthias Springer       bbArg.setType(memrefType);
460e07a7fd5SMatthias Springer 
461e07a7fd5SMatthias Springer       // Replace all uses of the original tensor bbArg.
462e07a7fd5SMatthias Springer       rewriter.setInsertionPointToStart(&frontBlock);
463e07a7fd5SMatthias Springer       if (!bbArgUses.empty()) {
464e07a7fd5SMatthias Springer         // Insert to_tensor because the remaining function body has not been
465e07a7fd5SMatthias Springer         // bufferized yet.
466e07a7fd5SMatthias Springer         Value toTensorOp =
467e07a7fd5SMatthias Springer             rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
468e07a7fd5SMatthias Springer         for (OpOperand *use : bbArgUses)
469e07a7fd5SMatthias Springer           use->set(toTensorOp);
470e07a7fd5SMatthias Springer       }
471e07a7fd5SMatthias Springer     }
472e07a7fd5SMatthias Springer 
473e07a7fd5SMatthias Springer     // 2. For each result, keep track of which inplace argument it reuses.
474e07a7fd5SMatthias Springer     SmallVector<Value> returnValues;
475e07a7fd5SMatthias Springer     for (OpOperand &returnOperand : returnOp->getOpOperands()) {
476e07a7fd5SMatthias Springer       Value returnVal = returnOperand.get();
477e07a7fd5SMatthias Springer 
478e07a7fd5SMatthias Springer       // If not a tensor type just forward it.
479e07a7fd5SMatthias Springer       if (!returnVal.getType().isa<RankedTensorType>()) {
480e07a7fd5SMatthias Springer         returnValues.push_back(returnVal);
481e07a7fd5SMatthias Springer         continue;
482e07a7fd5SMatthias Springer       }
483e07a7fd5SMatthias Springer 
484e07a7fd5SMatthias Springer       // If return operand is equivalent to some bbArg, no need to return it.
485e8f7d019SAlexander Belyaev       if (options.dropEquivalentFuncResults) {
486e07a7fd5SMatthias Springer         if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
487e07a7fd5SMatthias Springer                 funcOp, funcState, returnOperand.getOperandNumber())) {
488e07a7fd5SMatthias Springer           rewriter.setInsertionPoint(returnOp);
489e07a7fd5SMatthias Springer           Location loc = returnOp.getLoc();
490e07a7fd5SMatthias Springer           Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
491e8f7d019SAlexander Belyaev               loc,
492e8f7d019SAlexander Belyaev               getMemRefType(returnVal.getType().cast<TensorType>(), options),
493e07a7fd5SMatthias Springer               returnVal);
494e07a7fd5SMatthias Springer           BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
495e07a7fd5SMatthias Springer           // Note: This copy will fold away. It must be inserted here to ensure
496e07a7fd5SMatthias Springer           // that `returnVal` still has at least one use and does not fold away.
497e07a7fd5SMatthias Springer           if (failed(
498*248e113eSMatthias Springer                   options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
499e07a7fd5SMatthias Springer             return funcOp->emitError("could not generate copy for bbArg");
500e07a7fd5SMatthias Springer           continue;
501e07a7fd5SMatthias Springer         }
502e8f7d019SAlexander Belyaev       }
503e07a7fd5SMatthias Springer 
504e07a7fd5SMatthias Springer       returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
505e07a7fd5SMatthias Springer     }
506e07a7fd5SMatthias Springer 
507e07a7fd5SMatthias Springer     // 3. Rewrite the terminator without the in-place bufferizable values.
508e07a7fd5SMatthias Springer     returnOp.operandsMutable().assign(returnValues);
509e07a7fd5SMatthias Springer 
510e07a7fd5SMatthias Springer     // 4. Rewrite the FuncOp type to buffer form.
511e07a7fd5SMatthias Springer     funcOp.setType(FunctionType::get(op->getContext(), argTypes,
512e07a7fd5SMatthias Springer                                      ValueRange(returnValues).getTypes()));
513e07a7fd5SMatthias Springer 
514e07a7fd5SMatthias Springer     return success();
515e07a7fd5SMatthias Springer   }
516e07a7fd5SMatthias Springer 
517e07a7fd5SMatthias Springer   /// Return `true` if the given function argument is writable.
518e07a7fd5SMatthias Springer   bool isWritable(Operation *op, Value value,
519e07a7fd5SMatthias Springer                   const AnalysisState &state) const {
520e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
521e07a7fd5SMatthias Springer     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
522e07a7fd5SMatthias Springer     assert(bbArg && "expected BlockArgument");
523e07a7fd5SMatthias Springer 
524e07a7fd5SMatthias Springer     // "bufferization.writable" overrides other writability decisions. This is
525e07a7fd5SMatthias Springer     // currently used for testing only.
526e07a7fd5SMatthias Springer     if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
527e07a7fd5SMatthias Springer             bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
528e07a7fd5SMatthias Springer       return writable.getValue();
529e07a7fd5SMatthias Springer 
530e07a7fd5SMatthias Springer     // All function arguments are writable by default.
531e07a7fd5SMatthias Springer     return true;
532e07a7fd5SMatthias Springer   }
533e07a7fd5SMatthias Springer 
534e07a7fd5SMatthias Springer   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
535e07a7fd5SMatthias Springer };
536e07a7fd5SMatthias Springer 
537e07a7fd5SMatthias Springer } // namespace func_ext
538e07a7fd5SMatthias Springer } // namespace bufferization
539e07a7fd5SMatthias Springer } // namespace mlir
540e07a7fd5SMatthias Springer 
541e07a7fd5SMatthias Springer void mlir::bufferization::func_ext::
542e07a7fd5SMatthias Springer     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
543e07a7fd5SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
544e07a7fd5SMatthias Springer     func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
545e07a7fd5SMatthias Springer     func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
546e07a7fd5SMatthias Springer     func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
547e07a7fd5SMatthias Springer   });
548e07a7fd5SMatthias Springer }
549