//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user. If no layout map is specified, the default layout map
/// (as per `options.functionBoundaryTypeConversion`) is used.
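///
/// For illustration only (an assumed example, not taken from the code): with
/// `IdentityLayoutMap`, a `tensor<?x?xf32>` argument becomes
/// `memref<?x?xf32>`; with the fully dynamic fallback it becomes a memref
/// with a fully dynamic strided layout (the exact printed form depends on the
/// MLIR version).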
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

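  // The caller may pin a specific layout on the argument via the
  // `bufferization.buffer_layout` arg attribute (an AffineMapAttr), which
  // overrides the choice made above. Illustrative sketch (assumed syntax):
  //   %arg0: tensor<4x8xf32> {bufferization.buffer_layout =
  //                           affine_map<(d0, d1) -> (d1, d0)>}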
  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.has_value())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
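  ///
  /// Rough sketch of the rewrite (illustrative only; the exact memref types
  /// depend on the options and on the callee's bufferized signature):
  ///
  ///   %r = func.call @f(%t) : (tensor<?xf32>) -> tensor<?xf32>
  ///
  /// becomes
  ///
  ///   %m = bufferization.to_memref %t : memref<?xf32, ...>
  ///   %r = func.call @f(%m) : (memref<?xf32, ...>) -> memref<?xf32, ...>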
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the results
    // of the bufferized CallOp, unless a tensor result folds onto an operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // A mapping from return val indices of the old CallOp to return val
    // indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the bufferized function
    //    type of the callee (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, opOperand.get(), options);
        if (failed(maybeBuffer))
          return failure();
        buffer = *maybeBuffer;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the memref type expected by the callee does not match, introduce an
      // extra memref.cast that will either canonicalize away or fail
      // compilation until we can do something better.
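      // Illustrative example (assumed types): `memref.cast %buf :
      // memref<4xf32> to memref<?xf32, strided<[?], offset: ?>>` if the
      // callee was bufferized with a fully dynamic layout.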
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
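  ///
  /// Sketch of the signature rewrite (illustrative only; the memref layouts
  /// depend on `options.functionBoundaryTypeConversion`):
  ///
  ///   func.func @f(%t: tensor<?xf32>) -> tensor<?xf32> { ... }
  ///
  /// becomes
  ///
  ///   func.func @f(%m: memref<?xf32, ...>) -> memref<?xf32, ...> {
  ///     %t = bufferization.to_tensor %m : memref<?xf32, ...>
  ///     ...
  ///     %0 = bufferization.to_memref %ret : memref<?xf32, ...>
  ///     return %0 : memref<?xf32, ...>
  ///   }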
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors atm.
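    // E.g. (illustrative): a declaration such as
    //   func.func private @ext(tensor<?xf32>)
    // still gets a memref signature, whereas
    //   func.func private @ext() -> tensor<4xf32>
    // is rejected below.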
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
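    // Illustrative example (assumed syntax):
    //   func.func @f(%t: tensor<?xf32> {bufferization.writable = false})
    // marks %t as non-writable, so any write to it must be bufferized out of
    // place.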
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

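// Minimal usage sketch (assumed host code, for illustration only): clients
// typically call this alongside their dialect registrations, e.g.
//
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext context(registry);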
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}