//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

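/// Mark `funcOp` as in-progress and create empty entries for it in all
/// analysis maps. Asserts that no analysis entries existed for the function
/// before.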
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
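/// For example (illustrative): with `IdentityLayoutMap`, a `tensor<?x8xf32>`
/// argument becomes `memref<?x8xf32>`; otherwise a memref with a fully
/// dynamic layout map is used, unless the argument carries an explicit layout
/// via the `kBufferLayoutAttrName` arg attribute.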
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncAnalysisState attached to `state`. Asserts that it exists.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for this funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
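  // Note: The operands of a func.call map 1:1 to the callee bbArgs, so an
  // operand number can be used to index directly into the callee's analysis
  // results.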
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

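  // The results aliasing a tensor operand are the call results whose
  // corresponding return values were found (by the analysis) to alias the
  // matching callee bbArg.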
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> bbArgIdx =
                getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
          // Return operands that are equivalent to some bbArg are not
          // returned.
          FailureOr<Value> bufferOrFailure =
              state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
          if (failed(bufferOrFailure))
            return failure();
          replacementValues[returnValIdx] = *bufferOrFailure;
          newOperands[*bbArgIdx] = *bufferOrFailure;
          continue;
        }
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    // function type.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer does not have the memref type expected by the callee,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
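  // Return operands are read (their contents escape to the caller) but never
  // written by the ReturnOp itself.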
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but corresponding CallOp
  /// bufferizations fail unless `allowReturnAllocs` is enabled.
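  ///
  /// Example (illustrative): with `dropEquivalentFuncResults` enabled, a
  /// function `@f(%t: tensor<?xf32>) -> tensor<?xf32>` whose result is
  /// equivalent to `%t` bufferizes to a function that takes a single memref
  /// argument and returns no results.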
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, we
    // currently only support bodiless functions that do not return any
    // tensors.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If return operand is equivalent to some bbArg, no need to return it.
      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
                funcOp, funcState, returnOperand.getOperandNumber())) {
          // TODO: Use memref type with fully dynamic layout map and add folder
          // for memref.cast + memref.copy.
          Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
              loc, getMemRefType(tensorType, options), returnVal);
          BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
          // Note: This copy will fold away. It must be inserted here to ensure
          // that `returnVal` still has at least one use and does not fold away.
          if (failed(
                  options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
            return funcOp->emitError("could not generate copy for bbArg");
          continue;
        }
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }

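  // FuncOps act as barriers for allocation hoisting: allocations are not
  // hoisted out of the function body.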
  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

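/// Register the external bufferization models for func.call, func.func and
/// func.return with the func dialect.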
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}