//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
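///
/// Illustrative example (a sketch, not normative): for an argument
/// `%arg0: tensor<4xf32>`, the IdentityLayoutMap option yields
/// `memref<4xf32>`, while the default fully dynamic option yields a memref
/// with a fully dynamic strided layout (its exact textual form depends on the
/// MLIR version). A specific layout can also be requested per argument with
/// an arg attribute such as
/// `{bufferization.buffer_layout = affine_map<(d0) -> (d0)>}` (attribute name
/// as per `BufferizationDialect::kBufferLayoutAttrName`).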
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.hasValue())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.getValue()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
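///
/// Example (illustrative): given
///   func.func @f(%t: tensor<?xf32>) -> tensor<?xf32> {
///     return %t : tensor<?xf32>
///   }
/// return value #0 is equivalent to bbArg #0, so a query with
/// `returnValIdx == 0` would yield 0 (assuming the analysis has populated
/// `equivalentFuncArgs` for @f).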
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv.hasValue()) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() ==
                 maybeEquiv.getValue() &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
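  ///
  /// Rough sketch of the rewrite (illustrative only; the exact memref layouts
  /// depend on `options.functionBoundaryTypeConversion`):
  ///
  ///   %0 = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  ///
  /// becomes, once the callee signature has been bufferized:
  ///
  ///   %m = ... buffer backing %t ...
  ///   %c = memref.cast %m : memref<?xf32> to memref<?xf32, #layout>
  ///   %1 = func.call @callee(%c)
  ///       : (memref<?xf32, #layout>) -> memref<?xf32, #layout>
  ///
  /// The memref.cast is only inserted if the caller-side buffer type does not
  /// already match the callee's argument type.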
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // Mapping from return value indices of the old CallOp to return value
    // indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned as-is.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the bufferized callee
    // function type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer)
        buffer = getBuffer(rewriter, opOperand.get(), options);

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the memref type expected by the
      // callee, introduce an extra memref.cast that will either canonicalize
      // away or fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
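  ///
  /// Sketch of the rewrite (illustrative; the actual memref layouts depend on
  /// the chosen options):
  ///
  ///   func.func @f(%t: tensor<?xf32>) -> tensor<?xf32> {
  ///     ... use %t ...
  ///     return %r : tensor<?xf32>
  ///   }
  ///
  /// becomes roughly
  ///
  ///   func.func @f(%m: memref<?xf32, #layout>) -> memref<?xf32, #layout> {
  ///     %t = bufferization.to_tensor %m : ...
  ///     ... use %t ...
  ///     %rm = bufferization.to_memref %r : ...
  ///     return %rm : memref<?xf32, #layout>
  ///   }
  ///
  /// The to_tensor / to_memref ops are resolved when the rest of the function
  /// body is bufferized.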
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that do not return any tensors at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. Bufferize the return values: convert every tensor return value into
    // a memref. Non-tensor return values are forwarded unchanged.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator so that it returns the new memref values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
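  ///
  /// Example (illustrative): with
  ///   func.func @f(%t: tensor<?xf32> {bufferization.writable = false})
  /// the first argument is reported as not writable; without the attribute,
  /// function arguments default to writable. (Attribute name as per
  /// `BufferizationDialect::kWritableAttrName`.)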
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

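// Typical use from a pass or tool setup (an illustrative sketch, not part of
// this file's API surface):
//
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, memref::MemRefDialect,
//                   bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext context(registry);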
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}