//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
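  // Seed empty analysis entries for this function up front. The function
  // analysis populates these maps later; the asserts below guard against
  // analyzing the same function twice.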
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, a fully
/// dynamic map is used.
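/// For example (illustrative): an argument of type `tensor<?x8xf32>` without
/// a layout attribute bufferizes to a memref with a fully dynamic layout; if
/// the argument carries a `bufferization.buffer_layout` affine map attribute,
/// that map is used as the memref layout instead.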
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");
  BaseMemRefType memrefType = getMemRefType(tensorType, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`, or nullptr if the callee is not a
/// FuncOp or cannot be resolved.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
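/// E.g. (illustrative): if the analysis recorded that return value #0 of the
/// function is equivalent to bbArg #2, this returns 2 for `returnValIdx` == 0
/// and None for all other indices.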
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For results that are not dropped: A mapping from return val indices of
    // the old CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.
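    // For example (illustrative): given a callee whose tensor result #0 was
    // analyzed to be equivalent to bbArg #0,
    //   %r = func.call @callee(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
    // bufferizes to a call that no longer returns the tensor,
    //   func.call @callee(%m) : (memref<?xf32, #layout>) -> ()
    // and uses of %r are replaced with the buffer %m of the operand.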
    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> bbArgIdx =
                getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
          // Return operands that are equivalent to some bbArg are not
          // returned.
          FailureOr<Value> bufferOrFailure =
              state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
          if (failed(bufferOrFailure))
            return failure();
          replacementValues[returnValIdx] = *bufferOrFailure;
          newOperands[*bbArgIdx] = *bufferOrFailure;
          continue;
        }
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    // function type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer does not match the memref type expected by the callee,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
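      // For instance (illustrative), the operand buffer may carry a fully
      // dynamic layout while the callee signature uses a layout specified via
      // `bufferization.buffer_layout`; the cast bridges the two types as long
      // as they are cast-compatible.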
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form (using the
  /// canonical memref layout for now). This function bufferizes the function
  /// signature and the ReturnOp. When the entire function body has been
  /// bufferized, function return types can be switched to more concise memref
  /// types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but corresponding CallOp
  /// bufferizations fail unless `allowReturnAllocs` is enabled.
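  ///
  /// For example (illustrative): a function
  ///   func.func @fn(%t: tensor<?xf32>) -> tensor<?xf32>
  /// whose result was analyzed to be equivalent to its argument bufferizes
  /// (with `dropEquivalentFuncResults`) to
  ///   func.func @fn(%m: memref<?xf32, #fully_dynamic_layout>)
  /// with the equivalent tensor result dropped from the signature.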
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque, so we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // functions that do not return tensors are supported at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();

      // If not a tensor type, just forward it.
      if (!returnVal.getType().isa<RankedTensorType>()) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If return operand is equivalent to some bbArg, no need to return it.
      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
                funcOp, funcState, returnOperand.getOperandNumber())) {
          rewriter.setInsertionPoint(returnOp);
          Location loc = returnOp.getLoc();
          Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
              loc,
              getMemRefType(returnVal.getType().cast<TensorType>(), options),
              returnVal);
          BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
          // Note: This copy will fold away because source and destination
          // buffers are equivalent. It must be inserted here to ensure that
          // `returnVal` still has at least one use and does not fold away.
          if (failed(
                  createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
            return funcOp->emitError("could not generate copy for bbArg");
          continue;
        }
      }

      returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
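    // E.g. (illustrative): a bbArg declared as
    //   %t: tensor<?xf32> {bufferization.writable = false}
    // is treated as read-only.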
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}
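// Typical usage (illustrative sketch): clients register the external models
// before creating the MLIRContext, e.g.
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext ctx(registry);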