141cb504bSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
241cb504bSMatthias Springer //
341cb504bSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
441cb504bSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
541cb504bSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
641cb504bSMatthias Springer //
741cb504bSMatthias Springer //===----------------------------------------------------------------------===//
841cb504bSMatthias Springer
941cb504bSMatthias Springer #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
1041cb504bSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1141cb504bSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1241cb504bSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
1341cb504bSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1441cb504bSMatthias Springer #include "mlir/IR/Dialect.h"
1541cb504bSMatthias Springer #include "mlir/IR/Operation.h"
1641cb504bSMatthias Springer
1741cb504bSMatthias Springer using namespace mlir;
1841cb504bSMatthias Springer using namespace linalg;
1941cb504bSMatthias Springer using namespace mlir::bufferization;
2041cb504bSMatthias Springer
2141cb504bSMatthias Springer namespace {
2241cb504bSMatthias Springer
2341cb504bSMatthias Springer /// Generic conversion for any LinalgOp on tensors.
bufferizeLinalgOp(RewriterBase & rewriter,LinalgOp op,const BufferizationOptions & options)2441cb504bSMatthias Springer static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
25b55d55ecSMatthias Springer const BufferizationOptions &options) {
2641cb504bSMatthias Springer // Take a guard before anything else.
2741cb504bSMatthias Springer OpBuilder::InsertionGuard g(rewriter);
2841cb504bSMatthias Springer rewriter.setInsertionPoint(op);
2941cb504bSMatthias Springer
3041cb504bSMatthias Springer // Nothing to do. This op is already bufferized.
3141cb504bSMatthias Springer if (op.hasBufferSemantics())
3241cb504bSMatthias Springer return success();
3341cb504bSMatthias Springer
3441cb504bSMatthias Springer // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
3541cb504bSMatthias Springer // basis.
3641cb504bSMatthias Springer if (!op.hasTensorSemantics())
3741cb504bSMatthias Springer return op->emitError() << "op does not have tensor semantics";
3841cb504bSMatthias Springer
3941cb504bSMatthias Springer // New input operands for the cloned op.
4041cb504bSMatthias Springer SmallVector<Value> newInputBuffers;
4141cb504bSMatthias Springer newInputBuffers.reserve(op.getNumInputs());
4241cb504bSMatthias Springer for (OpOperand *opOperand : op.getInputOperands()) {
4341cb504bSMatthias Springer if (op.isScalar(opOperand)) {
4441cb504bSMatthias Springer newInputBuffers.push_back(opOperand->get());
4541cb504bSMatthias Springer continue;
4641cb504bSMatthias Springer }
47*5d50f51cSMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
48*5d50f51cSMatthias Springer if (failed(buffer))
49*5d50f51cSMatthias Springer return failure();
50*5d50f51cSMatthias Springer newInputBuffers.push_back(*buffer);
5141cb504bSMatthias Springer }
5241cb504bSMatthias Springer
5341cb504bSMatthias Springer // New output operands for the cloned op.
5441cb504bSMatthias Springer SmallVector<Value> newOutputBuffers;
5541cb504bSMatthias Springer for (OpResult opResult : op->getOpResults()) {
56b3ebe3beSMatthias Springer OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
57*5d50f51cSMatthias Springer FailureOr<Value> resultBuffer =
58*5d50f51cSMatthias Springer getBuffer(rewriter, opOperand->get(), options);
59*5d50f51cSMatthias Springer if (failed(resultBuffer))
60*5d50f51cSMatthias Springer return failure();
61*5d50f51cSMatthias Springer newOutputBuffers.push_back(*resultBuffer);
6241cb504bSMatthias Springer }
6341cb504bSMatthias Springer
6441cb504bSMatthias Springer // Merge input/output operands.
6541cb504bSMatthias Springer SmallVector<Value> newOperands = newInputBuffers;
6641cb504bSMatthias Springer newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
6741cb504bSMatthias Springer
6841cb504bSMatthias Springer // Set insertion point now that potential alloc/dealloc are introduced.
6941cb504bSMatthias Springer rewriter.setInsertionPoint(op);
7041cb504bSMatthias Springer // Clone the op, but use the new operands. Move the existing block into the
7141cb504bSMatthias Springer // new op. Since the new op does not have any tensor results, it does not
7241cb504bSMatthias Springer // return anything.
7341cb504bSMatthias Springer assert(op->getNumRegions() == 1 && "expected that op has 1 region");
7441cb504bSMatthias Springer auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
7541cb504bSMatthias Springer rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
7641cb504bSMatthias Springer rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
7741cb504bSMatthias Springer newOp->getRegion(0).begin());
7841cb504bSMatthias Springer
7941cb504bSMatthias Springer // Replace the results of the old op with the new output buffers.
8041cb504bSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
8141cb504bSMatthias Springer
8241cb504bSMatthias Springer return success();
8341cb504bSMatthias Springer }
8441cb504bSMatthias Springer
8541cb504bSMatthias Springer /// Bufferization of linalg.generic. Replace with a new linalg.generic that
8641cb504bSMatthias Springer /// operates entirely on memrefs.
8741cb504bSMatthias Springer template <typename OpTy>
8841cb504bSMatthias Springer struct LinalgOpInterface
8941cb504bSMatthias Springer : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
9041cb504bSMatthias Springer OpTy> {
bufferizesToMemoryRead__anon0de880b50111::LinalgOpInterface9141cb504bSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
929597b16aSMatthias Springer const AnalysisState &state) const {
9341cb504bSMatthias Springer // Operand is read if it is used in the computation.
9441cb504bSMatthias Springer auto genericOp = cast<linalg::LinalgOp>(op);
9541cb504bSMatthias Springer return genericOp.payloadUsesValueFromOperand(&opOperand);
9641cb504bSMatthias Springer }
9741cb504bSMatthias Springer
bufferizesToMemoryWrite__anon0de880b50111::LinalgOpInterface9841cb504bSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
999597b16aSMatthias Springer const AnalysisState &state) const {
10025bc6846SMatthias Springer // Operand is written to if it has an aliasing OpResult.
10141cb504bSMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(op);
10241cb504bSMatthias Springer return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
10341cb504bSMatthias Springer }
10441cb504bSMatthias Springer
10541cb504bSMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperand__anon0de880b50111::LinalgOpInterface10641cb504bSMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult,
1079597b16aSMatthias Springer const AnalysisState &state) const {
10841cb504bSMatthias Springer auto genericOp = cast<linalg::LinalgOp>(op);
10941cb504bSMatthias Springer
110ad2e635fSMatthias Springer // The i-th OpResult may alias with the i-th "out" tensor.
11125bc6846SMatthias Springer return {genericOp.getOutputOperand(opResult.getResultNumber())};
11241cb504bSMatthias Springer }
11341cb504bSMatthias Springer
getAliasingOpResult__anon0de880b50111::LinalgOpInterface1149597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1159597b16aSMatthias Springer const AnalysisState &state) const {
11641cb504bSMatthias Springer auto genericOp = cast<linalg::LinalgOp>(op);
11741cb504bSMatthias Springer
118ad2e635fSMatthias Springer // The i-th "out" tensor may alias with the i-th OpResult.
11925bc6846SMatthias Springer if (genericOp.isOutputTensor(&opOperand))
12025bc6846SMatthias Springer return {genericOp.getTiedOpResult(&opOperand)};
12125bc6846SMatthias Springer return {};
12225bc6846SMatthias Springer }
12325bc6846SMatthias Springer
bufferRelation__anon0de880b50111::LinalgOpInterface12441cb504bSMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
1259597b16aSMatthias Springer const AnalysisState &state) const {
12641cb504bSMatthias Springer return BufferRelation::Equivalent;
12741cb504bSMatthias Springer }
12841cb504bSMatthias Springer
bufferize__anon0de880b50111::LinalgOpInterface12941cb504bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130b55d55ecSMatthias Springer const BufferizationOptions &options) const {
131b55d55ecSMatthias Springer return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
13241cb504bSMatthias Springer }
13341cb504bSMatthias Springer };
13441cb504bSMatthias Springer
13541cb504bSMatthias Springer /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
13641cb504bSMatthias Springer /// the `BufferizableOpInterface` with each of them.
13777eee579SRiver Riddle template <typename... Ops>
13877eee579SRiver Riddle struct LinalgOpInterfaceHelper {
registerOpInterface__anon0de880b50111::LinalgOpInterfaceHelper13977eee579SRiver Riddle static void registerOpInterface(MLIRContext *ctx) {
14077eee579SRiver Riddle (void)std::initializer_list<int>{
14177eee579SRiver Riddle 0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
14241cb504bSMatthias Springer }
14341cb504bSMatthias Springer };
14441cb504bSMatthias Springer } // namespace
14541cb504bSMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)14641cb504bSMatthias Springer void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
14741cb504bSMatthias Springer DialectRegistry ®istry) {
14877eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
14941cb504bSMatthias Springer // Register all Linalg structured ops. `LinalgOp` is an interface and it is
15041cb504bSMatthias Springer // not possible to attach an external interface to an existing interface.
15141cb504bSMatthias Springer // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
15241cb504bSMatthias Springer LinalgOpInterfaceHelper<
15341cb504bSMatthias Springer #define GET_OP_LIST
15441cb504bSMatthias Springer #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
15577eee579SRiver Riddle >::registerOpInterface(ctx);
15677eee579SRiver Riddle });
15741cb504bSMatthias Springer }
158