//===- FuncBufferizableOpInterfaceImpl.cpp - Bufferization Impl. ----------===//
2e07a7fd5SMatthias Springer //
3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e07a7fd5SMatthias Springer //
7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer 
9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h"
16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
17e07a7fd5SMatthias Springer 
18e07a7fd5SMatthias Springer namespace mlir {
19e07a7fd5SMatthias Springer namespace bufferization {
20e07a7fd5SMatthias Springer namespace func_ext {
21e07a7fd5SMatthias Springer 
startFunctionAnalysis(FuncOp funcOp)22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
23e07a7fd5SMatthias Springer   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
24e07a7fd5SMatthias Springer   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
25e07a7fd5SMatthias Springer   auto createdAliasingOperands =
26e07a7fd5SMatthias Springer       aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
27e07a7fd5SMatthias Springer   auto createdAliasingResults =
28e07a7fd5SMatthias Springer       aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
29e07a7fd5SMatthias Springer   auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
30e07a7fd5SMatthias Springer   auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
31e07a7fd5SMatthias Springer   (void)createdEquiv;
32e07a7fd5SMatthias Springer   (void)createdAliasingOperands;
33e07a7fd5SMatthias Springer   (void)createdAliasingResults;
34e07a7fd5SMatthias Springer   (void)createdRead;
35e07a7fd5SMatthias Springer   (void)createdWritten;
36e07a7fd5SMatthias Springer #ifndef NDEBUG
37e07a7fd5SMatthias Springer   assert(createdEquiv.second && "equivalence info exists already");
38e07a7fd5SMatthias Springer   assert(createdAliasingOperands.second && "aliasing info exists already");
39e07a7fd5SMatthias Springer   assert(createdAliasingResults.second && "aliasing info exists already");
40e07a7fd5SMatthias Springer   assert(createdRead.second && "bbarg access info exists already");
41e07a7fd5SMatthias Springer   assert(createdWritten.second && "bbarg access info exists already");
42e07a7fd5SMatthias Springer #endif // NDEBUG
43e07a7fd5SMatthias Springer }
44e07a7fd5SMatthias Springer 
45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`.
46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp.
getAssumedUniqueReturnOp(FuncOp funcOp)47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
48e07a7fd5SMatthias Springer   func::ReturnOp returnOp;
49e07a7fd5SMatthias Springer   for (Block &b : funcOp.getBody()) {
50e07a7fd5SMatthias Springer     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
51e07a7fd5SMatthias Springer       if (returnOp)
52e07a7fd5SMatthias Springer         return nullptr;
53e07a7fd5SMatthias Springer       returnOp = candidateOp;
54e07a7fd5SMatthias Springer     }
55e07a7fd5SMatthias Springer   }
56e07a7fd5SMatthias Springer   return returnOp;
57e07a7fd5SMatthias Springer }
58e07a7fd5SMatthias Springer 
59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the
60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be
61f287da8aSMatthias Springer /// specified by the user. If no layout map is specified, the default layout map
62f287da8aSMatthias Springer /// (as per `options.functionBoundaryTypeConversion`) is used.
63e07a7fd5SMatthias Springer static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp,int64_t index,const BufferizationOptions & options)64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
65e07a7fd5SMatthias Springer                              const BufferizationOptions &options) {
66e07a7fd5SMatthias Springer   auto tensorType =
67e07a7fd5SMatthias Springer       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
68e07a7fd5SMatthias Springer   assert(tensorType && "expected TensorType");
69f287da8aSMatthias Springer 
70f287da8aSMatthias Springer   BaseMemRefType memrefType;
71f287da8aSMatthias Springer   if (options.functionBoundaryTypeConversion ==
72f287da8aSMatthias Springer       BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
73f287da8aSMatthias Springer     memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
74f287da8aSMatthias Springer   } else {
75f287da8aSMatthias Springer     // Note: Layout maps on function parameters cannot be inferred. The best we
76f287da8aSMatthias Springer     // can do at the moment is "fully dynamic".
77f287da8aSMatthias Springer     memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
78f287da8aSMatthias Springer   }
79e07a7fd5SMatthias Springer 
80e07a7fd5SMatthias Springer   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
81e07a7fd5SMatthias Springer       index, BufferizationDialect::kBufferLayoutAttrName);
82e07a7fd5SMatthias Springer   if (!layoutAttr)
83e07a7fd5SMatthias Springer     return memrefType;
84e07a7fd5SMatthias Springer 
85e07a7fd5SMatthias Springer   auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
86e07a7fd5SMatthias Springer   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
87e07a7fd5SMatthias Springer   return MemRefType::get(
88e07a7fd5SMatthias Springer       rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
89e07a7fd5SMatthias Springer       layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
90e07a7fd5SMatthias Springer }
91e07a7fd5SMatthias Springer 
92e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`.
getCalledFunction(CallOpInterface callOp)93e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) {
94e07a7fd5SMatthias Springer   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
95e07a7fd5SMatthias Springer   if (!sym)
96e07a7fd5SMatthias Springer     return nullptr;
97e07a7fd5SMatthias Springer   return dyn_cast_or_null<FuncOp>(
98e07a7fd5SMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
99e07a7fd5SMatthias Springer }
100e07a7fd5SMatthias Springer 
101e07a7fd5SMatthias Springer /// Get FuncAnalysisState.
102e07a7fd5SMatthias Springer static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState & state)103e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) {
104e07a7fd5SMatthias Springer   Optional<const FuncAnalysisState *> maybeState =
105e07a7fd5SMatthias Springer       state.getDialectState<FuncAnalysisState>(
106e07a7fd5SMatthias Springer           func::FuncDialect::getDialectNamespace());
1075413bf1bSKazu Hirata   assert(maybeState && "FuncAnalysisState does not exist");
108e07a7fd5SMatthias Springer   return **maybeState;
109e07a7fd5SMatthias Springer }
110e07a7fd5SMatthias Springer 
111e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp.
getFuncOpAnalysisState(const AnalysisState & state,FuncOp funcOp)112e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
113e07a7fd5SMatthias Springer                                                   FuncOp funcOp) {
114cd80617aSMatthias Springer   Optional<const FuncAnalysisState *> maybeState =
115cd80617aSMatthias Springer       state.getDialectState<FuncAnalysisState>(
116cd80617aSMatthias Springer           func::FuncDialect::getDialectNamespace());
117491d2701SKazu Hirata   if (!maybeState.has_value())
118cd80617aSMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
119*c27d8152SKazu Hirata   const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
120cd80617aSMatthias Springer   auto it = analyzedFuncOps.find(funcOp);
121cd80617aSMatthias Springer   if (it == analyzedFuncOps.end())
122e07a7fd5SMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
123e07a7fd5SMatthias Springer   return it->second;
124e07a7fd5SMatthias Springer }
125e07a7fd5SMatthias Springer 
126e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the
127e07a7fd5SMatthias Springer /// specified return value (if any).
getEquivalentFuncArgIdx(FuncOp funcOp,const FuncAnalysisState & state,int64_t returnValIdx)128e07a7fd5SMatthias Springer static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
129e07a7fd5SMatthias Springer                                                  const FuncAnalysisState &state,
130e07a7fd5SMatthias Springer                                                  int64_t returnValIdx) {
131e07a7fd5SMatthias Springer   auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
132e07a7fd5SMatthias Springer   if (funcOpIt == state.equivalentFuncArgs.end())
133e07a7fd5SMatthias Springer     // No equivalence info stores for funcOp.
134e07a7fd5SMatthias Springer     return None;
135e07a7fd5SMatthias Springer 
136e07a7fd5SMatthias Springer   auto retValIt = funcOpIt->getSecond().find(returnValIdx);
137e07a7fd5SMatthias Springer   if (retValIt == funcOpIt->getSecond().end())
138e07a7fd5SMatthias Springer     // Return value has no equivalent bbArg.
139e07a7fd5SMatthias Springer     return None;
140e07a7fd5SMatthias Springer 
141e07a7fd5SMatthias Springer   return retValIt->getSecond();
142e07a7fd5SMatthias Springer }
143e07a7fd5SMatthias Springer 
144e07a7fd5SMatthias Springer struct CallOpInterface
145e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CallOpInterface,
146e07a7fd5SMatthias Springer                                                     func::CallOp> {
bufferizesToMemoryReadmlir::bufferization::func_ext::CallOpInterface147e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
148e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
149e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
150e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
151e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
152e07a7fd5SMatthias Springer 
153e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
154e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is read.
155e07a7fd5SMatthias Springer       return true;
156e07a7fd5SMatthias Springer 
157cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
158e07a7fd5SMatthias Springer     return funcState.readBbArgs.lookup(funcOp).contains(
159e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
160e07a7fd5SMatthias Springer   }
161e07a7fd5SMatthias Springer 
bufferizesToMemoryWritemlir::bufferization::func_ext::CallOpInterface162e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
163e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
164e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
165e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
166e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
167e07a7fd5SMatthias Springer 
168e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
169e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is written.
170e07a7fd5SMatthias Springer       return true;
171e07a7fd5SMatthias Springer 
172cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
173e07a7fd5SMatthias Springer     return funcState.writtenBbArgs.lookup(funcOp).contains(
174e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
175e07a7fd5SMatthias Springer   }
176e07a7fd5SMatthias Springer 
getAliasingOpResultmlir::bufferization::func_ext::CallOpInterface177e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
178e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
179e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
180e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
181e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
182e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
183e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
184e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpResult may be aliasing.
185e07a7fd5SMatthias Springer       SmallVector<OpResult> result;
186e07a7fd5SMatthias Springer       for (OpResult opResult : op->getOpResults())
187e07a7fd5SMatthias Springer         if (opResult.getType().isa<TensorType>())
188e07a7fd5SMatthias Springer           result.push_back(opResult);
189e07a7fd5SMatthias Springer       return result;
190e07a7fd5SMatthias Springer     }
191e07a7fd5SMatthias Springer 
192e07a7fd5SMatthias Springer     // Get aliasing results from state.
193cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
194e07a7fd5SMatthias Springer     auto aliasingReturnVals =
195e07a7fd5SMatthias Springer         funcState.aliasingReturnVals.lookup(funcOp).lookup(
196e07a7fd5SMatthias Springer             opOperand.getOperandNumber());
197e07a7fd5SMatthias Springer     SmallVector<OpResult> result;
198e07a7fd5SMatthias Springer     for (int64_t resultIdx : aliasingReturnVals)
199e07a7fd5SMatthias Springer       result.push_back(callOp->getOpResult(resultIdx));
200e07a7fd5SMatthias Springer     return result;
201e07a7fd5SMatthias Springer   }
202e07a7fd5SMatthias Springer 
203e07a7fd5SMatthias Springer   SmallVector<OpOperand *>
getAliasingOpOperandmlir::bufferization::func_ext::CallOpInterface204e07a7fd5SMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
205e07a7fd5SMatthias Springer                        const AnalysisState &state) const {
206e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
207e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
208e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
209e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
210e07a7fd5SMatthias Springer         FuncOpAnalysisState::Analyzed) {
211e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpOperand may be aliasing.
212e07a7fd5SMatthias Springer       SmallVector<OpOperand *> result;
213e07a7fd5SMatthias Springer       for (OpOperand &opOperand : op->getOpOperands())
214e07a7fd5SMatthias Springer         if (opOperand.get().getType().isa<TensorType>())
215e07a7fd5SMatthias Springer           result.push_back(&opOperand);
216e07a7fd5SMatthias Springer       return result;
217e07a7fd5SMatthias Springer     }
218e07a7fd5SMatthias Springer 
219e07a7fd5SMatthias Springer     // Get aliasing bbArgs from state.
220cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
221e07a7fd5SMatthias Springer     auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
222e07a7fd5SMatthias Springer         opResult.getResultNumber());
223e07a7fd5SMatthias Springer     SmallVector<OpOperand *> result;
224e07a7fd5SMatthias Springer     for (int64_t bbArgIdx : aliasingFuncArgs)
225e07a7fd5SMatthias Springer       result.push_back(&callOp->getOpOperand(bbArgIdx));
226e07a7fd5SMatthias Springer     return result;
227e07a7fd5SMatthias Springer   }
228e07a7fd5SMatthias Springer 
bufferRelationmlir::bufferization::func_ext::CallOpInterface229e07a7fd5SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
230e07a7fd5SMatthias Springer                                 const AnalysisState &state) const {
23188539c5bSMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
23288539c5bSMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
23388539c5bSMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
23488539c5bSMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) !=
23588539c5bSMatthias Springer         FuncOpAnalysisState::Analyzed) {
23688539c5bSMatthias Springer       // Function not analyzed yet. The conservative answer is "None".
23788539c5bSMatthias Springer       return BufferRelation::None;
23888539c5bSMatthias Springer     }
23988539c5bSMatthias Springer 
240cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
24188539c5bSMatthias Springer     Optional<int64_t> maybeEquiv =
24288539c5bSMatthias Springer         getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
243037f0995SKazu Hirata     if (maybeEquiv) {
24488539c5bSMatthias Springer #ifndef NDEBUG
24588539c5bSMatthias Springer       SmallVector<OpOperand *> aliasingOpOperands =
24688539c5bSMatthias Springer           getAliasingOpOperand(op, opResult, state);
24788539c5bSMatthias Springer       assert(aliasingOpOperands.size() == 1 &&
24888539c5bSMatthias Springer              "expected exactly 1 aliasing OpOperand");
2496d5fc1e3SKazu Hirata       assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
25088539c5bSMatthias Springer              "inconsistent analysis state");
25188539c5bSMatthias Springer #endif
252e07a7fd5SMatthias Springer       return BufferRelation::Equivalent;
253e07a7fd5SMatthias Springer     }
25488539c5bSMatthias Springer     return BufferRelation::None;
25588539c5bSMatthias Springer   }
256e07a7fd5SMatthias Springer 
257e07a7fd5SMatthias Springer   /// All function arguments are writable. It is the responsibility of the
258e07a7fd5SMatthias Springer   /// CallOp to insert buffer copies where necessary.
bufferizemlir::bufferization::func_ext::CallOpInterface259e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
260b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
261e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
262e07a7fd5SMatthias Springer     unsigned numResults = callOp.getNumResults();
263e07a7fd5SMatthias Springer     unsigned numOperands = callOp->getNumOperands();
264e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
265e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
266e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
267e07a7fd5SMatthias Springer 
268e07a7fd5SMatthias Springer     // Result types of the bufferized CallOp.
269e07a7fd5SMatthias Springer     SmallVector<Type> resultTypes;
270e07a7fd5SMatthias Springer     // Replacement values for the existing CallOp. These are usually the results
271e07a7fd5SMatthias Springer     // of the bufferized CallOp, unless a tensor result folds onto an operand.
272e07a7fd5SMatthias Springer     SmallVector<Value> replacementValues(numResults, Value());
273e07a7fd5SMatthias Springer     // For non-tensor results: A mapping from return val indices of the old
274e07a7fd5SMatthias Springer     // CallOp to return val indices of the bufferized CallOp.
275e07a7fd5SMatthias Springer     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
276e07a7fd5SMatthias Springer     // Operands of the bufferized CallOp.
277e07a7fd5SMatthias Springer     SmallVector<Value> newOperands(numOperands, Value());
278e07a7fd5SMatthias Springer 
27988539c5bSMatthias Springer     // 1. Compute the result types of the new CallOp.
280e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
281e07a7fd5SMatthias Springer       unsigned returnValIdx = it.index();
282e07a7fd5SMatthias Springer       Type returnType = it.value();
283e07a7fd5SMatthias Springer       if (!returnType.isa<TensorType>()) {
284e07a7fd5SMatthias Springer         // Non-tensor values are returned.
285e07a7fd5SMatthias Springer         retValMapping[returnValIdx] = resultTypes.size();
286e07a7fd5SMatthias Springer         resultTypes.push_back(returnType);
287e07a7fd5SMatthias Springer         continue;
288e07a7fd5SMatthias Springer       }
289e07a7fd5SMatthias Springer 
29088539c5bSMatthias Springer       // Returning a memref.
291e07a7fd5SMatthias Springer       retValMapping[returnValIdx] = resultTypes.size();
292e07a7fd5SMatthias Springer       resultTypes.push_back(funcType.getResult(resultTypes.size()));
293e07a7fd5SMatthias Springer     }
294e07a7fd5SMatthias Springer 
295e07a7fd5SMatthias Springer     // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
296e07a7fd5SMatthias Springer     for (OpOperand &opOperand : callOp->getOpOperands()) {
297e07a7fd5SMatthias Springer       unsigned idx = opOperand.getOperandNumber();
298e07a7fd5SMatthias Springer       Value tensorOperand = opOperand.get();
299e07a7fd5SMatthias Springer 
300e07a7fd5SMatthias Springer       // Non-tensor operands are just copied.
301e07a7fd5SMatthias Springer       if (!tensorOperand.getType().isa<TensorType>()) {
302e07a7fd5SMatthias Springer         newOperands[idx] = tensorOperand;
303e07a7fd5SMatthias Springer         continue;
304e07a7fd5SMatthias Springer       }
305e07a7fd5SMatthias Springer 
30688539c5bSMatthias Springer       // Retrieve buffers for tensor operands.
307e07a7fd5SMatthias Springer       Value buffer = newOperands[idx];
3085d50f51cSMatthias Springer       if (!buffer) {
3095d50f51cSMatthias Springer         FailureOr<Value> maybeBuffer =
3105d50f51cSMatthias Springer             getBuffer(rewriter, opOperand.get(), options);
3115d50f51cSMatthias Springer         if (failed(maybeBuffer))
3125d50f51cSMatthias Springer           return failure();
3135d50f51cSMatthias Springer         buffer = *maybeBuffer;
3145d50f51cSMatthias Springer       }
315e07a7fd5SMatthias Springer 
316e07a7fd5SMatthias Springer       // Caller / callee type mismatch is handled with a CastOp.
317e07a7fd5SMatthias Springer       auto memRefType = funcType.getInput(idx);
318e07a7fd5SMatthias Springer       // Since we don't yet have a clear layout story, to_memref may
319e07a7fd5SMatthias Springer       // conservatively turn tensors into more dynamic memref than necessary.
320e07a7fd5SMatthias Springer       // If the memref type of the callee fails, introduce an extra memref.cast
321e07a7fd5SMatthias Springer       // that will either canonicalize away or fail compilation until we can do
322e07a7fd5SMatthias Springer       // something better.
323e07a7fd5SMatthias Springer       if (buffer.getType() != memRefType) {
324e07a7fd5SMatthias Springer         assert(
325e07a7fd5SMatthias Springer             memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
326e07a7fd5SMatthias Springer             "CallOp::bufferize: cast incompatible");
327e07a7fd5SMatthias Springer         Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
328e07a7fd5SMatthias Springer                                                            memRefType, buffer);
329e07a7fd5SMatthias Springer         buffer = castBuffer;
330e07a7fd5SMatthias Springer       }
331e07a7fd5SMatthias Springer       newOperands[idx] = buffer;
332e07a7fd5SMatthias Springer     }
333e07a7fd5SMatthias Springer 
334e07a7fd5SMatthias Springer     // 3. Create the new CallOp.
335e07a7fd5SMatthias Springer     Operation *newCallOp = rewriter.create<func::CallOp>(
336e07a7fd5SMatthias Springer         callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
337e07a7fd5SMatthias Springer     newCallOp->setAttrs(callOp->getAttrs());
33888539c5bSMatthias Springer     // Get replacement values.
339e07a7fd5SMatthias Springer     for (unsigned i = 0; i < replacementValues.size(); ++i) {
340e07a7fd5SMatthias Springer       if (replacementValues[i])
341e07a7fd5SMatthias Springer         continue;
342e07a7fd5SMatthias Springer       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
343e07a7fd5SMatthias Springer     }
344e07a7fd5SMatthias Springer 
345e07a7fd5SMatthias Springer     // 4. Replace the old op with the new op.
346e07a7fd5SMatthias Springer     replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
347e07a7fd5SMatthias Springer 
348e07a7fd5SMatthias Springer     return success();
349e07a7fd5SMatthias Springer   }
350e07a7fd5SMatthias Springer };
351e07a7fd5SMatthias Springer 
352e07a7fd5SMatthias Springer struct ReturnOpInterface
353e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
354e07a7fd5SMatthias Springer                                                     func::ReturnOp> {
bufferizesToMemoryReadmlir::bufferization::func_ext::ReturnOpInterface355e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
356e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
357e07a7fd5SMatthias Springer     return true;
358e07a7fd5SMatthias Springer   }
359e07a7fd5SMatthias Springer 
bufferizesToMemoryWritemlir::bufferization::func_ext::ReturnOpInterface360e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
361e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
362e07a7fd5SMatthias Springer     return false;
363e07a7fd5SMatthias Springer   }
364e07a7fd5SMatthias Springer 
getAliasingOpResultmlir::bufferization::func_ext::ReturnOpInterface365e07a7fd5SMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
366e07a7fd5SMatthias Springer                                             const AnalysisState &state) const {
367e07a7fd5SMatthias Springer     return {};
368e07a7fd5SMatthias Springer   }
369e07a7fd5SMatthias Springer 
bufferizemlir::bufferization::func_ext::ReturnOpInterface370e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
371b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
372e07a7fd5SMatthias Springer #ifndef NDEBUG
373e07a7fd5SMatthias Springer     auto returnOp = cast<func::ReturnOp>(op);
374e07a7fd5SMatthias Springer     assert(isa<FuncOp>(returnOp->getParentOp()) &&
375e07a7fd5SMatthias Springer            "only support FuncOp parent for ReturnOp");
376e07a7fd5SMatthias Springer #endif // NDEBUG
377e07a7fd5SMatthias Springer 
378e07a7fd5SMatthias Springer     // ReturnOps are bufferized as part of FuncOps.
3790b293bf0SMatthias Springer     return success();
380e07a7fd5SMatthias Springer   }
381e07a7fd5SMatthias Springer };
382e07a7fd5SMatthias Springer 
383e07a7fd5SMatthias Springer struct FuncOpInterface
384e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
385f287da8aSMatthias Springer   /// Rewrite function bbArgs and return values into buffer form. This function
386f287da8aSMatthias Springer   /// bufferizes the function signature and the ReturnOp. When the entire
387f287da8aSMatthias Springer   /// function body has been bufferized, function return types can be switched
388f287da8aSMatthias Springer   /// to more concise memref types as part of `foldMemRefCasts`.
389e07a7fd5SMatthias Springer   ///
390e07a7fd5SMatthias Springer   /// All function bbArgs are writable unless they are explicitly marked as
391e07a7fd5SMatthias Springer   /// read-only. Callers must insert copies when needed.
bufferizemlir::bufferization::func_ext::FuncOpInterface392e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
393b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
394e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
395e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
396e07a7fd5SMatthias Springer 
397e07a7fd5SMatthias Springer     // Construct the bufferized function type.
398e07a7fd5SMatthias Springer     SmallVector<Type> argTypes;
399e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
400e07a7fd5SMatthias Springer       Type argType = it.value();
401e07a7fd5SMatthias Springer       if (auto tensorType = argType.dyn_cast<TensorType>()) {
402e07a7fd5SMatthias Springer         argTypes.push_back(
403e07a7fd5SMatthias Springer             getBufferizedFunctionArgType(funcOp, it.index(), options));
404e07a7fd5SMatthias Springer         continue;
405e07a7fd5SMatthias Springer       }
406e07a7fd5SMatthias Springer       argTypes.push_back(argType);
407e07a7fd5SMatthias Springer     }
408e07a7fd5SMatthias Springer 
409e07a7fd5SMatthias Springer     // Bodiless functions are assumed opaque and we cannot know the
410e07a7fd5SMatthias Springer     // bufferization contract they want to enforce. As a consequence, only
411e07a7fd5SMatthias Springer     // support functions that don't return any tensors atm.
412e07a7fd5SMatthias Springer     if (funcOp.getBody().empty()) {
413e07a7fd5SMatthias Springer       SmallVector<Type> retTypes;
414e07a7fd5SMatthias Springer       for (Type resultType : funcType.getResults()) {
415e07a7fd5SMatthias Springer         if (resultType.isa<TensorType>())
416e07a7fd5SMatthias Springer           return funcOp->emitError() << "cannot bufferize bodiless function "
417e07a7fd5SMatthias Springer                                      << "that returns a tensor";
418e07a7fd5SMatthias Springer         retTypes.push_back(resultType);
419e07a7fd5SMatthias Springer       }
420e07a7fd5SMatthias Springer       funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
421e07a7fd5SMatthias Springer       return success();
422e07a7fd5SMatthias Springer     }
423e07a7fd5SMatthias Springer 
424e07a7fd5SMatthias Springer     // TODO: Support functions with multiple returns.
425e07a7fd5SMatthias Springer     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
426e07a7fd5SMatthias Springer     assert(returnOp && "expected func with single return op");
427f287da8aSMatthias Springer     Location loc = returnOp.getLoc();
428e07a7fd5SMatthias Springer 
429e07a7fd5SMatthias Springer     // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
430e07a7fd5SMatthias Springer     Block &frontBlock = funcOp.getBody().front();
431e07a7fd5SMatthias Springer     for (BlockArgument &bbArg : frontBlock.getArguments()) {
432e07a7fd5SMatthias Springer       auto tensorType = bbArg.getType().dyn_cast<TensorType>();
433e07a7fd5SMatthias Springer       // Non-tensor types stay the same.
434e07a7fd5SMatthias Springer       if (!tensorType)
435e07a7fd5SMatthias Springer         continue;
436e07a7fd5SMatthias Springer 
437e07a7fd5SMatthias Springer       // Collect all uses of the bbArg.
438e07a7fd5SMatthias Springer       SmallVector<OpOperand *> bbArgUses;
439e07a7fd5SMatthias Springer       for (OpOperand &use : bbArg.getUses())
440e07a7fd5SMatthias Springer         bbArgUses.push_back(&use);
441e07a7fd5SMatthias Springer 
442e07a7fd5SMatthias Springer       // Change the bbArg type to memref.
443e07a7fd5SMatthias Springer       Type memrefType =
444e07a7fd5SMatthias Springer           getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
445e07a7fd5SMatthias Springer       bbArg.setType(memrefType);
446e07a7fd5SMatthias Springer 
447e07a7fd5SMatthias Springer       // Replace all uses of the original tensor bbArg.
448e07a7fd5SMatthias Springer       rewriter.setInsertionPointToStart(&frontBlock);
449e07a7fd5SMatthias Springer       if (!bbArgUses.empty()) {
450e07a7fd5SMatthias Springer         // Insert to_tensor because the remaining function body has not been
451e07a7fd5SMatthias Springer         // bufferized yet.
452e07a7fd5SMatthias Springer         Value toTensorOp =
453e07a7fd5SMatthias Springer             rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
454e07a7fd5SMatthias Springer         for (OpOperand *use : bbArgUses)
455e07a7fd5SMatthias Springer           use->set(toTensorOp);
456e07a7fd5SMatthias Springer       }
457e07a7fd5SMatthias Springer     }
458e07a7fd5SMatthias Springer 
459e07a7fd5SMatthias Springer     // 2. For each result, keep track of which inplace argument it reuses.
460e07a7fd5SMatthias Springer     SmallVector<Value> returnValues;
461e07a7fd5SMatthias Springer     for (OpOperand &returnOperand : returnOp->getOpOperands()) {
462e07a7fd5SMatthias Springer       Value returnVal = returnOperand.get();
463f287da8aSMatthias Springer       auto tensorType = returnVal.getType().dyn_cast<TensorType>();
464f287da8aSMatthias Springer       rewriter.setInsertionPoint(returnOp);
465e07a7fd5SMatthias Springer 
466e07a7fd5SMatthias Springer       // If not a tensor type just forward it.
467f287da8aSMatthias Springer       if (!tensorType) {
468e07a7fd5SMatthias Springer         returnValues.push_back(returnVal);
469e07a7fd5SMatthias Springer         continue;
470e07a7fd5SMatthias Springer       }
471e07a7fd5SMatthias Springer 
472f287da8aSMatthias Springer       BaseMemRefType resultType;
473f287da8aSMatthias Springer       if (options.functionBoundaryTypeConversion ==
474f287da8aSMatthias Springer           BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
475f287da8aSMatthias Springer         resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
476f287da8aSMatthias Springer       } else {
477f287da8aSMatthias Springer         // Note: If `InferLayoutMap`, cast are later folded away.
478f287da8aSMatthias Springer         resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
479f287da8aSMatthias Springer       }
480f287da8aSMatthias Springer       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
481f287da8aSMatthias Springer           loc, resultType, returnVal);
482f287da8aSMatthias Springer       returnValues.push_back(toMemrefOp);
483e07a7fd5SMatthias Springer     }
484e07a7fd5SMatthias Springer 
485e07a7fd5SMatthias Springer     // 3. Rewrite the terminator without the in-place bufferizable values.
486e07a7fd5SMatthias Springer     returnOp.operandsMutable().assign(returnValues);
487e07a7fd5SMatthias Springer 
488e07a7fd5SMatthias Springer     // 4. Rewrite the FuncOp type to buffer form.
489e07a7fd5SMatthias Springer     funcOp.setType(FunctionType::get(op->getContext(), argTypes,
490e07a7fd5SMatthias Springer                                      ValueRange(returnValues).getTypes()));
491e07a7fd5SMatthias Springer 
492e07a7fd5SMatthias Springer     return success();
493e07a7fd5SMatthias Springer   }
494e07a7fd5SMatthias Springer 
495e07a7fd5SMatthias Springer   /// Return `true` if the given function argument is writable.
isWritablemlir::bufferization::func_ext::FuncOpInterface496e07a7fd5SMatthias Springer   bool isWritable(Operation *op, Value value,
497e07a7fd5SMatthias Springer                   const AnalysisState &state) const {
498e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
499e07a7fd5SMatthias Springer     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
500e07a7fd5SMatthias Springer     assert(bbArg && "expected BlockArgument");
501e07a7fd5SMatthias Springer 
502e07a7fd5SMatthias Springer     // "bufferization.writable" overrides other writability decisions. This is
503e07a7fd5SMatthias Springer     // currently used for testing only.
504e07a7fd5SMatthias Springer     if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
505e07a7fd5SMatthias Springer             bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
506e07a7fd5SMatthias Springer       return writable.getValue();
507e07a7fd5SMatthias Springer 
508e07a7fd5SMatthias Springer     // All function arguments are writable by default.
509e07a7fd5SMatthias Springer     return true;
510e07a7fd5SMatthias Springer   }
511e07a7fd5SMatthias Springer };
512e07a7fd5SMatthias Springer 
513e07a7fd5SMatthias Springer } // namespace func_ext
514e07a7fd5SMatthias Springer } // namespace bufferization
515e07a7fd5SMatthias Springer } // namespace mlir
516e07a7fd5SMatthias Springer 
517e07a7fd5SMatthias Springer void mlir::bufferization::func_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)518e07a7fd5SMatthias Springer     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
519e07a7fd5SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
520e07a7fd5SMatthias Springer     func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
521e07a7fd5SMatthias Springer     func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
522e07a7fd5SMatthias Springer     func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
523e07a7fd5SMatthias Springer   });
524e07a7fd5SMatthias Springer }
525