1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2e07a7fd5SMatthias Springer //
3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e07a7fd5SMatthias Springer //
7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer 
9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h"
16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
17e07a7fd5SMatthias Springer 
18e07a7fd5SMatthias Springer namespace mlir {
19e07a7fd5SMatthias Springer namespace bufferization {
20e07a7fd5SMatthias Springer namespace func_ext {
21e07a7fd5SMatthias Springer 
22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
23e07a7fd5SMatthias Springer   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
24e07a7fd5SMatthias Springer   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
25e07a7fd5SMatthias Springer   auto createdAliasingOperands =
26e07a7fd5SMatthias Springer       aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
27e07a7fd5SMatthias Springer   auto createdAliasingResults =
28e07a7fd5SMatthias Springer       aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
29e07a7fd5SMatthias Springer   auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
30e07a7fd5SMatthias Springer   auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
31e07a7fd5SMatthias Springer   (void)createdEquiv;
32e07a7fd5SMatthias Springer   (void)createdAliasingOperands;
33e07a7fd5SMatthias Springer   (void)createdAliasingResults;
34e07a7fd5SMatthias Springer   (void)createdRead;
35e07a7fd5SMatthias Springer   (void)createdWritten;
36e07a7fd5SMatthias Springer #ifndef NDEBUG
37e07a7fd5SMatthias Springer   assert(createdEquiv.second && "equivalence info exists already");
38e07a7fd5SMatthias Springer   assert(createdAliasingOperands.second && "aliasing info exists already");
39e07a7fd5SMatthias Springer   assert(createdAliasingResults.second && "aliasing info exists already");
40e07a7fd5SMatthias Springer   assert(createdRead.second && "bbarg access info exists already");
41e07a7fd5SMatthias Springer   assert(createdWritten.second && "bbarg access info exists already");
42e07a7fd5SMatthias Springer #endif // NDEBUG
43e07a7fd5SMatthias Springer }
44e07a7fd5SMatthias Springer 
45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`.
46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp.
47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
48e07a7fd5SMatthias Springer   func::ReturnOp returnOp;
49e07a7fd5SMatthias Springer   for (Block &b : funcOp.getBody()) {
50e07a7fd5SMatthias Springer     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
51e07a7fd5SMatthias Springer       if (returnOp)
52e07a7fd5SMatthias Springer         return nullptr;
53e07a7fd5SMatthias Springer       returnOp = candidateOp;
54e07a7fd5SMatthias Springer     }
55e07a7fd5SMatthias Springer   }
56e07a7fd5SMatthias Springer   return returnOp;
57e07a7fd5SMatthias Springer }
58e07a7fd5SMatthias Springer 
59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the
60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be
61*f287da8aSMatthias Springer /// specified by the user. If no layout map is specified, the default layout map
62*f287da8aSMatthias Springer /// (as per `options.functionBoundaryTypeConversion`) is used.
63e07a7fd5SMatthias Springer static BaseMemRefType
64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
65e07a7fd5SMatthias Springer                              const BufferizationOptions &options) {
66e07a7fd5SMatthias Springer   auto tensorType =
67e07a7fd5SMatthias Springer       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
68e07a7fd5SMatthias Springer   assert(tensorType && "expected TensorType");
69*f287da8aSMatthias Springer 
70*f287da8aSMatthias Springer   BaseMemRefType memrefType;
71*f287da8aSMatthias Springer   if (options.functionBoundaryTypeConversion ==
72*f287da8aSMatthias Springer       BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
73*f287da8aSMatthias Springer     memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
74*f287da8aSMatthias Springer   } else {
75*f287da8aSMatthias Springer     // Note: Layout maps on function parameters cannot be inferred. The best we
76*f287da8aSMatthias Springer     // can do at the moment is "fully dynamic".
77*f287da8aSMatthias Springer     memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
78*f287da8aSMatthias Springer   }
79e07a7fd5SMatthias Springer 
80e07a7fd5SMatthias Springer   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
81e07a7fd5SMatthias Springer       index, BufferizationDialect::kBufferLayoutAttrName);
82e07a7fd5SMatthias Springer   if (!layoutAttr)
83e07a7fd5SMatthias Springer     return memrefType;
84e07a7fd5SMatthias Springer 
85e07a7fd5SMatthias Springer   auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
86e07a7fd5SMatthias Springer   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
87e07a7fd5SMatthias Springer   return MemRefType::get(
88e07a7fd5SMatthias Springer       rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
89e07a7fd5SMatthias Springer       layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
90e07a7fd5SMatthias Springer }
91e07a7fd5SMatthias Springer 
92e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`.
93e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) {
94e07a7fd5SMatthias Springer   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
95e07a7fd5SMatthias Springer   if (!sym)
96e07a7fd5SMatthias Springer     return nullptr;
97e07a7fd5SMatthias Springer   return dyn_cast_or_null<FuncOp>(
98e07a7fd5SMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
99e07a7fd5SMatthias Springer }
100e07a7fd5SMatthias Springer 
101e07a7fd5SMatthias Springer /// Get FuncAnalysisState.
102e07a7fd5SMatthias Springer static const FuncAnalysisState &
103e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) {
104e07a7fd5SMatthias Springer   Optional<const FuncAnalysisState *> maybeState =
105e07a7fd5SMatthias Springer       state.getDialectState<FuncAnalysisState>(
106e07a7fd5SMatthias Springer           func::FuncDialect::getDialectNamespace());
107e07a7fd5SMatthias Springer   assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
108e07a7fd5SMatthias Springer   return **maybeState;
109e07a7fd5SMatthias Springer }
110e07a7fd5SMatthias Springer 
111e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp.
112e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
113e07a7fd5SMatthias Springer                                                   FuncOp funcOp) {
114e07a7fd5SMatthias Springer   const FuncAnalysisState &funcState = getFuncAnalysisState(state);
115e07a7fd5SMatthias Springer   auto it = funcState.analyzedFuncOps.find(funcOp);
116e07a7fd5SMatthias Springer   if (it == funcState.analyzedFuncOps.end())
117e07a7fd5SMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
118e07a7fd5SMatthias Springer   return it->second;
119e07a7fd5SMatthias Springer }
120e07a7fd5SMatthias Springer 
121e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the
122e07a7fd5SMatthias Springer /// specified return value (if any).
123e07a7fd5SMatthias Springer static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
124e07a7fd5SMatthias Springer                                                  const FuncAnalysisState &state,
125e07a7fd5SMatthias Springer                                                  int64_t returnValIdx) {
126e07a7fd5SMatthias Springer   auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
127e07a7fd5SMatthias Springer   if (funcOpIt == state.equivalentFuncArgs.end())
128e07a7fd5SMatthias Springer     // No equivalence info stores for funcOp.
129e07a7fd5SMatthias Springer     return None;
130e07a7fd5SMatthias Springer 
131e07a7fd5SMatthias Springer   auto retValIt = funcOpIt->getSecond().find(returnValIdx);
132e07a7fd5SMatthias Springer   if (retValIt == funcOpIt->getSecond().end())
133e07a7fd5SMatthias Springer     // Return value has no equivalent bbArg.
134e07a7fd5SMatthias Springer     return None;
135e07a7fd5SMatthias Springer 
136e07a7fd5SMatthias Springer   return retValIt->getSecond();
137e07a7fd5SMatthias Springer }
138e07a7fd5SMatthias Springer 
139e07a7fd5SMatthias Springer struct CallOpInterface
140e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CallOpInterface,
141e07a7fd5SMatthias Springer                                                     func::CallOp> {
142e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
143e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
144e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
145e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
146e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
147e07a7fd5SMatthias Springer 
148e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
149e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
150e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is read.
151e07a7fd5SMatthias Springer       return true;
152e07a7fd5SMatthias Springer 
153e07a7fd5SMatthias Springer     return funcState.readBbArgs.lookup(funcOp).contains(
154e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
155e07a7fd5SMatthias Springer   }
156e07a7fd5SMatthias Springer 
157e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
158e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
159e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
160e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
161e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
162e07a7fd5SMatthias Springer 
163e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
164e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
165e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is written.
166e07a7fd5SMatthias Springer       return true;
167e07a7fd5SMatthias Springer 
168e07a7fd5SMatthias Springer     return funcState.writtenBbArgs.lookup(funcOp).contains(
169e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
170e07a7fd5SMatthias Springer   }
171e07a7fd5SMatthias Springer 
172e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
173e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
174e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
175e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
176e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
177e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
178e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
179e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
180e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpResult may be aliasing.
181e07a7fd5SMatthias Springer       SmallVector<OpResult> result;
182e07a7fd5SMatthias Springer       for (OpResult opResult : op->getOpResults())
183e07a7fd5SMatthias Springer         if (opResult.getType().isa<TensorType>())
184e07a7fd5SMatthias Springer           result.push_back(opResult);
185e07a7fd5SMatthias Springer       return result;
186e07a7fd5SMatthias Springer     }
187e07a7fd5SMatthias Springer 
188e07a7fd5SMatthias Springer     // Get aliasing results from state.
189e07a7fd5SMatthias Springer     auto aliasingReturnVals =
190e07a7fd5SMatthias Springer         funcState.aliasingReturnVals.lookup(funcOp).lookup(
191e07a7fd5SMatthias Springer             opOperand.getOperandNumber());
192e07a7fd5SMatthias Springer     SmallVector<OpResult> result;
193e07a7fd5SMatthias Springer     for (int64_t resultIdx : aliasingReturnVals)
194e07a7fd5SMatthias Springer       result.push_back(callOp->getOpResult(resultIdx));
195e07a7fd5SMatthias Springer     return result;
196e07a7fd5SMatthias Springer   }
197e07a7fd5SMatthias Springer 
198e07a7fd5SMatthias Springer   SmallVector<OpOperand *>
199e07a7fd5SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
200e07a7fd5SMatthias Springer                        const AnalysisState &state) const {
201e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
202e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
203e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
204e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
205e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
206e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
207e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpOperand may be aliasing.
208e07a7fd5SMatthias Springer       SmallVector<OpOperand *> result;
209e07a7fd5SMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
210e07a7fd5SMatthias Springer         if (opOperand.get().getType().isa<TensorType>())
211e07a7fd5SMatthias Springer           result.push_back(&opOperand);
212e07a7fd5SMatthias Springer       return result;
213e07a7fd5SMatthias Springer     }
214e07a7fd5SMatthias Springer 
215e07a7fd5SMatthias Springer     // Get aliasing bbArgs from state.
216e07a7fd5SMatthias Springer     auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
217e07a7fd5SMatthias Springer         opResult.getResultNumber());
218e07a7fd5SMatthias Springer     SmallVector<OpOperand *> result;
219e07a7fd5SMatthias Springer     for (int64_t bbArgIdx : aliasingFuncArgs)
220e07a7fd5SMatthias Springer       result.push_back(&callOp->getOpOperand(bbArgIdx));
221e07a7fd5SMatthias Springer     return result;
222e07a7fd5SMatthias Springer   }
223e07a7fd5SMatthias Springer 
224e07a7fd5SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
225e07a7fd5SMatthias Springer                                 const AnalysisState &state) const {
226e07a7fd5SMatthias Springer     return BufferRelation::Equivalent;
227e07a7fd5SMatthias Springer   }
228e07a7fd5SMatthias Springer 
229e07a7fd5SMatthias Springer   /// All function arguments are writable. It is the responsibility of the
230e07a7fd5SMatthias Springer   /// CallOp to insert buffer copies where necessary.
231e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
232e07a7fd5SMatthias Springer                           BufferizationState &state) const {
233e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
234e07a7fd5SMatthias Springer     unsigned numResults = callOp.getNumResults();
235e07a7fd5SMatthias Springer     unsigned numOperands = callOp->getNumOperands();
236e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
237e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
238e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
239e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState =
240e07a7fd5SMatthias Springer         getFuncAnalysisState(state.getAnalysisState());
241e07a7fd5SMatthias Springer     const OneShotBufferizationOptions &options =
242e07a7fd5SMatthias Springer         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
243e07a7fd5SMatthias Springer 
244e07a7fd5SMatthias Springer     // Result types of the bufferized CallOp.
245e07a7fd5SMatthias Springer     SmallVector<Type> resultTypes;
246e07a7fd5SMatthias Springer     // Replacement values for the existing CallOp. These are usually the results
247e07a7fd5SMatthias Springer     // of the bufferized CallOp, unless a tensor result folds onto an operand.
248e07a7fd5SMatthias Springer     SmallVector<Value> replacementValues(numResults, Value());
249e07a7fd5SMatthias Springer     // For non-tensor results: A mapping from return val indices of the old
250e07a7fd5SMatthias Springer     // CallOp to return val indices of the bufferized CallOp.
251e07a7fd5SMatthias Springer     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
252e07a7fd5SMatthias Springer     // Operands of the bufferized CallOp.
253e07a7fd5SMatthias Springer     SmallVector<Value> newOperands(numOperands, Value());
254e07a7fd5SMatthias Springer 
255e07a7fd5SMatthias Springer     // Based on previously gathered equivalence information, we know if a
256e07a7fd5SMatthias Springer     // tensor result folds onto an operand. These are the only tensor value
257e07a7fd5SMatthias Springer     // results that are supported at the moment.
258e07a7fd5SMatthias Springer     //
259e07a7fd5SMatthias Springer     // For tensors return values that do not fold onto an operand, additional
260e07a7fd5SMatthias Springer     // work is needed (TODO) to either:
261e07a7fd5SMatthias Springer     // * hoist a result into an inplaceable operand or
262e07a7fd5SMatthias Springer     // * devise a better representation to truly return a buffer.
263e07a7fd5SMatthias Springer     //
264e07a7fd5SMatthias Springer     // Note: If a function has no body, no equivalence information is
265e07a7fd5SMatthias Springer     // available. Consequently, a tensor return value cannot be proven to fold
266e07a7fd5SMatthias Springer     // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
267e07a7fd5SMatthias Springer     // the moment.
268e07a7fd5SMatthias Springer 
269e07a7fd5SMatthias Springer     // 1. Compute the result types of the new CallOp. Tensor results that are
270e07a7fd5SMatthias Springer     // equivalent to a FuncOp bbArg are no longer returned.
271e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
272e07a7fd5SMatthias Springer       unsigned returnValIdx = it.index();
273e07a7fd5SMatthias Springer       Type returnType = it.value();
274e07a7fd5SMatthias Springer       if (!returnType.isa<TensorType>()) {
275e07a7fd5SMatthias Springer         // Non-tensor values are returned.
276e07a7fd5SMatthias Springer         retValMapping[returnValIdx] = resultTypes.size();
277e07a7fd5SMatthias Springer         resultTypes.push_back(returnType);
278e07a7fd5SMatthias Springer         continue;
279e07a7fd5SMatthias Springer       }
280e07a7fd5SMatthias Springer 
281e8f7d019SAlexander Belyaev       if (options.dropEquivalentFuncResults) {
282e07a7fd5SMatthias Springer         if (Optional<int64_t> bbArgIdx =
283e07a7fd5SMatthias Springer                 getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
284e07a7fd5SMatthias Springer           // Return operands that are equivalent to some bbArg, are not
285e07a7fd5SMatthias Springer           // returned.
286e07a7fd5SMatthias Springer           FailureOr<Value> bufferOrFailure =
287e07a7fd5SMatthias Springer               state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
288e07a7fd5SMatthias Springer           if (failed(bufferOrFailure))
289e07a7fd5SMatthias Springer             return failure();
290e07a7fd5SMatthias Springer           replacementValues[returnValIdx] = *bufferOrFailure;
291e07a7fd5SMatthias Springer           newOperands[*bbArgIdx] = *bufferOrFailure;
292e07a7fd5SMatthias Springer           continue;
293e07a7fd5SMatthias Springer         }
294e8f7d019SAlexander Belyaev       }
295e07a7fd5SMatthias Springer 
296e07a7fd5SMatthias Springer       if (!options.allowReturnAllocs)
297e07a7fd5SMatthias Springer         return callOp->emitError(
298e07a7fd5SMatthias Springer             "call to FuncOp that returns non-equivalent tensors not supported");
299e07a7fd5SMatthias Springer 
300e07a7fd5SMatthias Springer       // Returning a memref. This memref is not equivalent to any bbArg. It is
301e07a7fd5SMatthias Springer       // likely a newly allocated buffer. We may want to hoist such allocations
302e07a7fd5SMatthias Springer       // to the call site in the future.
303e07a7fd5SMatthias Springer       retValMapping[returnValIdx] = resultTypes.size();
304e07a7fd5SMatthias Springer       resultTypes.push_back(funcType.getResult(resultTypes.size()));
305e07a7fd5SMatthias Springer     }
306e07a7fd5SMatthias Springer 
307e07a7fd5SMatthias Springer     // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
308e07a7fd5SMatthias Springer     for (OpOperand &opOperand : callOp->getOpOperands()) {
309e07a7fd5SMatthias Springer       unsigned idx = opOperand.getOperandNumber();
310e07a7fd5SMatthias Springer       Value tensorOperand = opOperand.get();
311e07a7fd5SMatthias Springer 
312e07a7fd5SMatthias Springer       // Non-tensor operands are just copied.
313e07a7fd5SMatthias Springer       if (!tensorOperand.getType().isa<TensorType>()) {
314e07a7fd5SMatthias Springer         newOperands[idx] = tensorOperand;
315e07a7fd5SMatthias Springer         continue;
316e07a7fd5SMatthias Springer       }
317e07a7fd5SMatthias Springer 
318e07a7fd5SMatthias Springer       // Retrieve buffers for tensor operands. Tensor operand buffers, who's
319e07a7fd5SMatthias Springer       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
320e07a7fd5SMatthias Springer       // already stored in `newOperands` during Step 1.
321e07a7fd5SMatthias Springer       Value buffer = newOperands[idx];
322e07a7fd5SMatthias Springer       if (!buffer) {
323e07a7fd5SMatthias Springer         FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
324e07a7fd5SMatthias Springer         if (failed(bufferOrFailure))
325e07a7fd5SMatthias Springer           return failure();
326e07a7fd5SMatthias Springer         buffer = *bufferOrFailure;
327e07a7fd5SMatthias Springer       }
328e07a7fd5SMatthias Springer 
329e07a7fd5SMatthias Springer       // Caller / callee type mismatch is handled with a CastOp.
330e07a7fd5SMatthias Springer       auto memRefType = funcType.getInput(idx);
331e07a7fd5SMatthias Springer       // Since we don't yet have a clear layout story, to_memref may
332e07a7fd5SMatthias Springer       // conservatively turn tensors into more dynamic memref than necessary.
333e07a7fd5SMatthias Springer       // If the memref type of the callee fails, introduce an extra memref.cast
334e07a7fd5SMatthias Springer       // that will either canonicalize away or fail compilation until we can do
335e07a7fd5SMatthias Springer       // something better.
336e07a7fd5SMatthias Springer       if (buffer.getType() != memRefType) {
337e07a7fd5SMatthias Springer         assert(
338e07a7fd5SMatthias Springer             memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
339e07a7fd5SMatthias Springer             "CallOp::bufferize: cast incompatible");
340e07a7fd5SMatthias Springer         Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
341e07a7fd5SMatthias Springer                                                            memRefType, buffer);
342e07a7fd5SMatthias Springer         buffer = castBuffer;
343e07a7fd5SMatthias Springer       }
344e07a7fd5SMatthias Springer       newOperands[idx] = buffer;
345e07a7fd5SMatthias Springer     }
346e07a7fd5SMatthias Springer 
347e07a7fd5SMatthias Springer     // 3. Create the new CallOp.
348e07a7fd5SMatthias Springer     Operation *newCallOp = rewriter.create<func::CallOp>(
349e07a7fd5SMatthias Springer         callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
350e07a7fd5SMatthias Springer     newCallOp->setAttrs(callOp->getAttrs());
351e07a7fd5SMatthias Springer     // Get replacement values for non-tensor / non-equivalent results.
352e07a7fd5SMatthias Springer     for (unsigned i = 0; i < replacementValues.size(); ++i) {
353e07a7fd5SMatthias Springer       if (replacementValues[i])
354e07a7fd5SMatthias Springer         continue;
355e07a7fd5SMatthias Springer       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
356e07a7fd5SMatthias Springer     }
357e07a7fd5SMatthias Springer 
358e07a7fd5SMatthias Springer     // 4. Replace the old op with the new op.
359e07a7fd5SMatthias Springer     replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
360e07a7fd5SMatthias Springer 
361e07a7fd5SMatthias Springer     return success();
362e07a7fd5SMatthias Springer   }
363e07a7fd5SMatthias Springer };
364e07a7fd5SMatthias Springer 
365e07a7fd5SMatthias Springer struct ReturnOpInterface
366e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
367e07a7fd5SMatthias Springer                                                     func::ReturnOp> {
368e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
369e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
370e07a7fd5SMatthias Springer     return true;
371e07a7fd5SMatthias Springer   }
372e07a7fd5SMatthias Springer 
373e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
374e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
375e07a7fd5SMatthias Springer     return false;
376e07a7fd5SMatthias Springer   }
377e07a7fd5SMatthias Springer 
378e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
379e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
380e07a7fd5SMatthias Springer     return {};
381e07a7fd5SMatthias Springer   }
382e07a7fd5SMatthias Springer 
383e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
384e07a7fd5SMatthias Springer                           BufferizationState &state) const {
385e07a7fd5SMatthias Springer #ifndef NDEBUG
386e07a7fd5SMatthias Springer     auto returnOp = cast<func::ReturnOp>(op);
387e07a7fd5SMatthias Springer     assert(isa<FuncOp>(returnOp->getParentOp()) &&
388e07a7fd5SMatthias Springer            "only support FuncOp parent for ReturnOp");
389e07a7fd5SMatthias Springer #endif // NDEBUG
390e07a7fd5SMatthias Springer 
391e07a7fd5SMatthias Springer     // ReturnOps are bufferized as part of FuncOps.
392e07a7fd5SMatthias Springer     return failure();
393e07a7fd5SMatthias Springer   }
394e07a7fd5SMatthias Springer };
395e07a7fd5SMatthias Springer 
396e07a7fd5SMatthias Springer struct FuncOpInterface
397e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
398*f287da8aSMatthias Springer   /// Rewrite function bbArgs and return values into buffer form. This function
399*f287da8aSMatthias Springer   /// bufferizes the function signature and the ReturnOp. When the entire
400*f287da8aSMatthias Springer   /// function body has been bufferized, function return types can be switched
401*f287da8aSMatthias Springer   /// to more concise memref types as part of `foldMemRefCasts`.
402e07a7fd5SMatthias Springer   ///
403e07a7fd5SMatthias Springer   /// When a tensor function argument is known to be equivalent to a tensor
404e07a7fd5SMatthias Springer   /// result, it is dropped from the return values.
405e07a7fd5SMatthias Springer   ///
406e07a7fd5SMatthias Springer   /// All function bbArgs are writable unless they are explicitly marked as
407e07a7fd5SMatthias Springer   /// read-only. Callers must insert copies when needed.
408e07a7fd5SMatthias Springer   ///
409e07a7fd5SMatthias Springer   /// Note: Returning a memref is possible, but corresponding CallOp
410e07a7fd5SMatthias Springer   /// bufferizations fail unless `allowReturnAllocs`.
411e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
412e07a7fd5SMatthias Springer                           BufferizationState &state) const {
413e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
414e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
415e07a7fd5SMatthias Springer     const FuncAnalysisState &funcState =
416e07a7fd5SMatthias Springer         getFuncAnalysisState(state.getAnalysisState());
417e8f7d019SAlexander Belyaev     const OneShotBufferizationOptions &options =
418e8f7d019SAlexander Belyaev         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
419e07a7fd5SMatthias Springer 
420e07a7fd5SMatthias Springer     // Construct the bufferized function type.
421e07a7fd5SMatthias Springer     SmallVector<Type> argTypes;
422e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
423e07a7fd5SMatthias Springer       Type argType = it.value();
424e07a7fd5SMatthias Springer       if (auto tensorType = argType.dyn_cast<TensorType>()) {
425e07a7fd5SMatthias Springer         argTypes.push_back(
426e07a7fd5SMatthias Springer             getBufferizedFunctionArgType(funcOp, it.index(), options));
427e07a7fd5SMatthias Springer         continue;
428e07a7fd5SMatthias Springer       }
429e07a7fd5SMatthias Springer       argTypes.push_back(argType);
430e07a7fd5SMatthias Springer     }
431e07a7fd5SMatthias Springer 
432e07a7fd5SMatthias Springer     // Bodiless functions are assumed opaque and we cannot know the
433e07a7fd5SMatthias Springer     // bufferization contract they want to enforce. As a consequence, only
434e07a7fd5SMatthias Springer     // support functions that don't return any tensors atm.
435e07a7fd5SMatthias Springer     if (funcOp.getBody().empty()) {
436e07a7fd5SMatthias Springer       SmallVector<Type> retTypes;
437e07a7fd5SMatthias Springer       for (Type resultType : funcType.getResults()) {
438e07a7fd5SMatthias Springer         if (resultType.isa<TensorType>())
439e07a7fd5SMatthias Springer           return funcOp->emitError() << "cannot bufferize bodiless function "
440e07a7fd5SMatthias Springer                                      << "that returns a tensor";
441e07a7fd5SMatthias Springer         retTypes.push_back(resultType);
442e07a7fd5SMatthias Springer       }
443e07a7fd5SMatthias Springer       funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
444e07a7fd5SMatthias Springer       return success();
445e07a7fd5SMatthias Springer     }
446e07a7fd5SMatthias Springer 
447e07a7fd5SMatthias Springer     // TODO: Support functions with multiple returns.
448e07a7fd5SMatthias Springer     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
449e07a7fd5SMatthias Springer     assert(returnOp && "expected func with single return op");
450*f287da8aSMatthias Springer     Location loc = returnOp.getLoc();
451e07a7fd5SMatthias Springer 
452e07a7fd5SMatthias Springer     // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
453e07a7fd5SMatthias Springer     Block &frontBlock = funcOp.getBody().front();
454e07a7fd5SMatthias Springer     for (BlockArgument &bbArg : frontBlock.getArguments()) {
455e07a7fd5SMatthias Springer       auto tensorType = bbArg.getType().dyn_cast<TensorType>();
456e07a7fd5SMatthias Springer       // Non-tensor types stay the same.
457e07a7fd5SMatthias Springer       if (!tensorType)
458e07a7fd5SMatthias Springer         continue;
459e07a7fd5SMatthias Springer 
460e07a7fd5SMatthias Springer       // Collect all uses of the bbArg.
461e07a7fd5SMatthias Springer       SmallVector<OpOperand *> bbArgUses;
462e07a7fd5SMatthias Springer       for (OpOperand &use : bbArg.getUses())
463e07a7fd5SMatthias Springer         bbArgUses.push_back(&use);
464e07a7fd5SMatthias Springer 
465e07a7fd5SMatthias Springer       // Change the bbArg type to memref.
466e07a7fd5SMatthias Springer       Type memrefType =
467e07a7fd5SMatthias Springer           getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
468e07a7fd5SMatthias Springer       bbArg.setType(memrefType);
469e07a7fd5SMatthias Springer 
470e07a7fd5SMatthias Springer       // Replace all uses of the original tensor bbArg.
471e07a7fd5SMatthias Springer       rewriter.setInsertionPointToStart(&frontBlock);
472e07a7fd5SMatthias Springer       if (!bbArgUses.empty()) {
473e07a7fd5SMatthias Springer         // Insert to_tensor because the remaining function body has not been
474e07a7fd5SMatthias Springer         // bufferized yet.
475e07a7fd5SMatthias Springer         Value toTensorOp =
476e07a7fd5SMatthias Springer             rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
477e07a7fd5SMatthias Springer         for (OpOperand *use : bbArgUses)
478e07a7fd5SMatthias Springer           use->set(toTensorOp);
479e07a7fd5SMatthias Springer       }
480e07a7fd5SMatthias Springer     }
481e07a7fd5SMatthias Springer 
482e07a7fd5SMatthias Springer     // 2. For each result, keep track of which inplace argument it reuses.
483e07a7fd5SMatthias Springer     SmallVector<Value> returnValues;
484e07a7fd5SMatthias Springer     for (OpOperand &returnOperand : returnOp->getOpOperands()) {
485e07a7fd5SMatthias Springer       Value returnVal = returnOperand.get();
486*f287da8aSMatthias Springer       auto tensorType = returnVal.getType().dyn_cast<TensorType>();
487*f287da8aSMatthias Springer       rewriter.setInsertionPoint(returnOp);
488e07a7fd5SMatthias Springer 
489e07a7fd5SMatthias Springer       // If not a tensor type just forward it.
490*f287da8aSMatthias Springer       if (!tensorType) {
491e07a7fd5SMatthias Springer         returnValues.push_back(returnVal);
492e07a7fd5SMatthias Springer         continue;
493e07a7fd5SMatthias Springer       }
494e07a7fd5SMatthias Springer 
495e07a7fd5SMatthias Springer       // If return operand is equivalent to some bbArg, no need to return it.
496e8f7d019SAlexander Belyaev       if (options.dropEquivalentFuncResults) {
497e07a7fd5SMatthias Springer         if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
498e07a7fd5SMatthias Springer                 funcOp, funcState, returnOperand.getOperandNumber())) {
499*f287da8aSMatthias Springer           // TODO: Use memref type with fully dynamic layout map and add folder
500*f287da8aSMatthias Springer           // for memref.cast + memref.copy.
501e07a7fd5SMatthias Springer           Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
502*f287da8aSMatthias Springer               loc, getMemRefType(tensorType, options), returnVal);
503e07a7fd5SMatthias Springer           BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
504e07a7fd5SMatthias Springer           // Note: This copy will fold away. It must be inserted here to ensure
505e07a7fd5SMatthias Springer           // that `returnVal` still has at least one use and does not fold away.
506e07a7fd5SMatthias Springer           if (failed(
507248e113eSMatthias Springer                   options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
508e07a7fd5SMatthias Springer             return funcOp->emitError("could not generate copy for bbArg");
509e07a7fd5SMatthias Springer           continue;
510e07a7fd5SMatthias Springer         }
511e8f7d019SAlexander Belyaev       }
512e07a7fd5SMatthias Springer 
513*f287da8aSMatthias Springer       BaseMemRefType resultType;
514*f287da8aSMatthias Springer       if (options.functionBoundaryTypeConversion ==
515*f287da8aSMatthias Springer           BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
516*f287da8aSMatthias Springer         resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
517*f287da8aSMatthias Springer       } else {
518*f287da8aSMatthias Springer         // Note: If `InferLayoutMap`, cast are later folded away.
519*f287da8aSMatthias Springer         resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
520*f287da8aSMatthias Springer       }
521*f287da8aSMatthias Springer       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
522*f287da8aSMatthias Springer           loc, resultType, returnVal);
523*f287da8aSMatthias Springer       returnValues.push_back(toMemrefOp);
524e07a7fd5SMatthias Springer     }
525e07a7fd5SMatthias Springer 
526e07a7fd5SMatthias Springer     // 3. Rewrite the terminator without the in-place bufferizable values.
527e07a7fd5SMatthias Springer     returnOp.operandsMutable().assign(returnValues);
528e07a7fd5SMatthias Springer 
529e07a7fd5SMatthias Springer     // 4. Rewrite the FuncOp type to buffer form.
530e07a7fd5SMatthias Springer     funcOp.setType(FunctionType::get(op->getContext(), argTypes,
531e07a7fd5SMatthias Springer                                      ValueRange(returnValues).getTypes()));
532e07a7fd5SMatthias Springer 
533e07a7fd5SMatthias Springer     return success();
534e07a7fd5SMatthias Springer   }
535e07a7fd5SMatthias Springer 
536e07a7fd5SMatthias Springer   /// Return `true` if the given function argument is writable.
537e07a7fd5SMatthias Springer   bool isWritable(Operation *op, Value value,
538e07a7fd5SMatthias Springer                   const AnalysisState &state) const {
539e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
540e07a7fd5SMatthias Springer     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
541e07a7fd5SMatthias Springer     assert(bbArg && "expected BlockArgument");
542e07a7fd5SMatthias Springer 
543e07a7fd5SMatthias Springer     // "bufferization.writable" overrides other writability decisions. This is
544e07a7fd5SMatthias Springer     // currently used for testing only.
545e07a7fd5SMatthias Springer     if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
546e07a7fd5SMatthias Springer             bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
547e07a7fd5SMatthias Springer       return writable.getValue();
548e07a7fd5SMatthias Springer 
549e07a7fd5SMatthias Springer     // All function arguments are writable by default.
550e07a7fd5SMatthias Springer     return true;
551e07a7fd5SMatthias Springer   }
552e07a7fd5SMatthias Springer 
553e07a7fd5SMatthias Springer   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
554e07a7fd5SMatthias Springer };
555e07a7fd5SMatthias Springer 
556e07a7fd5SMatthias Springer } // namespace func_ext
557e07a7fd5SMatthias Springer } // namespace bufferization
558e07a7fd5SMatthias Springer } // namespace mlir
559e07a7fd5SMatthias Springer 
560e07a7fd5SMatthias Springer void mlir::bufferization::func_ext::
561e07a7fd5SMatthias Springer     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
562e07a7fd5SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
563e07a7fd5SMatthias Springer     func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
564e07a7fd5SMatthias Springer     func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
565e07a7fd5SMatthias Springer     func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
566e07a7fd5SMatthias Springer   });
567e07a7fd5SMatthias Springer }
568