//===- FuncBufferizableOpInterfaceImpl.cpp - Bufferization of func ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user. If no layout map is specified, the default layout map
/// (as per `options.functionBoundaryTypeConversion`) is used.
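///
/// Example (illustrative only; the printed form of the fully dynamic layout
/// depends on the MLIR version): with `IdentityLayoutMap`, a `tensor<4x8xf32>`
/// argument becomes `memref<4x8xf32>`; otherwise it becomes a memref with a
/// fully dynamic strided layout. A `bufferization.buffer_layout` arg attribute,
/// if present, overrides the layout chosen above.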
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.has_value())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
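///
/// For example, assuming the analysis has recorded that return value #0 of
/// `funcOp` is equivalent to bbArg #2, calling this function with
/// `returnValIdx == 0` returns `2`.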
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
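  ///
  /// A rough sketch of the rewrite (illustrative only; the actual memref
  /// layouts depend on `options.functionBoundaryTypeConversion`):
  ///
  ///   %r = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  ///
  /// becomes something like:
  ///
  ///   %r = func.call @callee(%m) : (memref<?xf32, ...>) -> memref<?xf32, ...>
  ///
  /// where %m is the buffer of %t, possibly wrapped in a memref.cast if its
  /// type does not match the callee signature exactly.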
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the results
    // of the bufferized CallOp, unless a tensor result folds onto an operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    // `funcType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer)
        buffer = getBuffer(rewriter, opOperand.get(), options);

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into a more dynamic memref than necessary.
      // If the buffer type does not match the memref type expected by the
      // callee, introduce an extra memref.cast that will either canonicalize
      // away or fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
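  ///
  /// A rough sketch of the signature rewrite (illustrative only; exact memref
  /// layouts depend on `options.functionBoundaryTypeConversion`):
  ///
  ///   func.func @fn(%t: tensor<8xf32>) -> tensor<8xf32> {
  ///     ...
  ///     return %u : tensor<8xf32>
  ///   }
  ///
  /// becomes, with identity layouts, something like:
  ///
  ///   func.func @fn(%m: memref<8xf32>) -> memref<8xf32> {
  ///     %t = bufferization.to_tensor %m : memref<8xf32>
  ///     ...
  ///     %out = bufferization.to_memref %u : memref<8xf32>
  ///     return %out : memref<8xf32>
  ///   }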
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque: we cannot know the bufferization
    // contract they want to enforce. As a consequence, only functions that do
    // not return any tensors are supported at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. Bufferize the return values: wrap every returned tensor in a
    // to_memref op so that the terminator returns memrefs.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator to return the bufferized values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

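// Typical usage (a minimal sketch): register the external models with the
// DialectRegistry that is used to construct the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   mlir::bufferization::func_ext::
//       registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);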
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}