//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
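///
/// Illustrative example (hypothetical types, assuming no
/// `bufferization.buffer_layout` argument attribute is present): with
/// `IdentityLayoutMap`, a `tensor<4x8xf32>` argument becomes `memref<4x8xf32>`;
/// with the fully dynamic default, it becomes a `memref<4x8xf32>` with a fully
/// dynamic strided layout.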
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.has_value())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
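///
/// E.g. (illustrative): for a function that simply returns its tensor argument
/// (`return %arg0`), return value #0 is equivalent to bbArg #0.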
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

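/// Bufferization of func.call. Tensor operands are replaced with buffers, and
/// the call is rewritten against the callee's bufferized signature, inserting
/// memref.cast ops where caller and callee buffer types differ.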
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
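  ///
  /// Illustrative sketch (hypothetical IR; exact buffer types depend on the
  /// chosen layout options):
  ///   %0 = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  /// is rewritten to roughly:
  ///   %m = bufferization.to_memref %t : memref<?xf32, ...>
  ///   %r = func.call @callee(%m) : (memref<?xf32, ...>) -> memref<?xf32, ...>
  /// and uses of %0 are replaced based on %r.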
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // A mapping from result indices of the old CallOp to result indices of
    // the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    //    function type.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, opOperand.get(), options);
        if (failed(maybeBuffer))
          return failure();
        buffer = *maybeBuffer;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the callee's expected memref type,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
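      // E.g. (illustrative): a `memref<4xf32>` operand may be cast to a callee
      // parameter type with a fully dynamic strided layout.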
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
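  ///
  /// Illustrative sketch (hypothetical IR; exact buffer types depend on the
  /// chosen layout options):
  ///   func.func @foo(%t: tensor<4xf32>) -> tensor<4xf32> {
  ///     ...
  ///     return %r : tensor<4xf32>
  ///   }
  /// becomes roughly:
  ///   func.func @foo(%m: memref<4xf32, ...>) -> memref<4xf32, ...> {
  ///     %t = bufferization.to_tensor %m : memref<4xf32, ...>
  ///     ...
  ///     %out = bufferization.to_memref %r : memref<4xf32, ...>
  ///     return %out : memref<4xf32, ...>
  ///   }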
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque: we cannot know the bufferization
    // contract they want to enforce. As a consequence, only functions that do
    // not return any tensors are supported at the moment.
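    // E.g. (illustrative): a declaration such as
    //   func.func private @ext(tensor<4xf32>)
    // gets its tensor argument rewritten to a memref, whereas a declaration
    // returning a tensor is rejected below.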
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: With `InferLayoutMap`, the casts are folded away later.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
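    // E.g. (illustrative):
    //   func.func @f(%t: tensor<?xf32> {bufferization.writable = false})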
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

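// Typical usage (illustrative): call this while assembling the DialectRegistry
// for a pass pipeline, e.g.
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);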
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}