//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}
/// Return the `index`-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, a fully
/// dynamic map is used.
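///
/// As an illustration (the exact layout depends on the `BufferizationOptions`
/// in use), a `tensor<?x8xf32>` argument without a user-specified layout map
/// would typically bufferize to a memref with a fully dynamic strided layout:
///
///   tensor<?x8xf32>
///     -> memref<?x8xf32,
///               affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
///
/// whereas an argument annotated with an identity layout map (via the arg
/// attribute named by `BufferizationDialect::kBufferLayoutAttrName`) would
/// bufferize to a plain `memref<?x8xf32>`.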
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");
  BaseMemRefType memrefType = getMemRefType(tensorType, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncAnalysisState attached to the given AnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
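///
/// For example, given a (hypothetical) function that returns one of its
/// arguments unmodified:
///
///   func.func @f(%t: tensor<?xf32>) -> tensor<?xf32> {
///     return %t : tensor<?xf32>
///   }
///
/// the analysis records return value 0 as equivalent to bbArg 0, so this
/// function returns `0` for `returnValIdx == 0`.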
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // A mapping from return value indices of the old CallOp to return value
    // indices of the bufferized CallOp, for all results that are still
    // returned (non-tensor results and tensor results that do not fold onto
    // an operand).
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.
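    //
    // As a sketch (hypothetical IR, assuming the analysis determined that the
    // callee's tensor result is equivalent to its tensor argument), a call
    // such as
    //
    //   %r = func.call @f(%t) : (tensor<?xf32>) -> tensor<?xf32>
    //
    // bufferizes to a call that no longer returns the tensor; the operand's
    // buffer doubles as the replacement value for %r:
    //
    //   func.call @f(%m) : (memref<?xf32, #layout>) -> ()
    //
    // where %m is the buffer of %t and all uses of %r are replaced with %m
    // (wrapped in a to_tensor op while the surrounding code is still in
    // tensor form).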

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (Optional<int64_t> bbArgIdx =
              getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
        // Return operands that are equivalent to some bbArg are not
        // returned.
        FailureOr<Value> bufferOrFailure =
            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
        if (failed(bufferOrFailure))
          return failure();
        replacementValues[returnValIdx] = *bufferOrFailure;
        newOperands[*bbArgIdx] = *bufferOrFailure;
        continue;
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the bufferized callee
    // type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the callee's expected memref type,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
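      //
      // For instance (illustrative types): if the operand buffer has type
      // memref<?xf32> but the bufferized callee expects a fully dynamic
      // layout, the generated cast would look like:
      //
      //   %cast = memref.cast %buffer : memref<?xf32>
      //       to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>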
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form (using the
  /// canonical memref layout for now). This function bufferizes the function
  /// signature and the ReturnOp. When the entire function body has been
  /// bufferized, function return types can be switched to more concise memref
  /// types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but bufferization of the
  /// corresponding CallOp fails unless `allowReturnAllocs` is enabled.
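  ///
  /// A sketch of the rewrite (hypothetical IR; the argument layout map
  /// depends on the bufferization options): a function that returns its
  /// argument, such as
  ///
  ///   func.func @id(%t: tensor<?xf32>) -> tensor<?xf32> {
  ///     return %t : tensor<?xf32>
  ///   }
  ///
  /// becomes a function with a memref argument and, because the return value
  /// is equivalent to the bbArg, no results:
  ///
  ///   func.func @id(%m: memref<?xf32, #layout>) {
  ///     return
  ///   }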
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const BufferizationOptions &options = state.getOptions();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed to be opaque; we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // bodiless functions that do not return any tensors are supported at the
    // moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which in-place argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();

      // If not a tensor type, just forward it.
      if (!returnVal.getType().isa<RankedTensorType>()) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If the return operand is equivalent to some bbArg, there is no need
      // to return it.
      if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
              funcOp, funcState, returnOperand.getOperandNumber())) {
        rewriter.setInsertionPoint(returnOp);
        Location loc = returnOp.getLoc();
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
            returnVal);
        BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
        // Note: This copy will fold away. It must be inserted here to ensure
        // that `returnVal` still has at least one use and does not fold away.
        if (failed(
                createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
          return funcOp->emitError("could not generate copy for bbArg");
        continue;
      }

      returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
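  ///
  /// For example, the tensor argument in the following (illustrative) IR is
  /// marked non-writable, so bufferization must copy before writing to it:
  ///
  ///   func.func @f(%t: tensor<?xf32> {bufferization.writable = false})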
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

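// A typical (hypothetical) client registers the external models alongside the
// dialects before creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, memref::MemRefDialect,
//                   bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext ctx(registry);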
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}