141cb504bSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
241cb504bSMatthias Springer //
341cb504bSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
441cb504bSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
541cb504bSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
641cb504bSMatthias Springer //
741cb504bSMatthias Springer //===----------------------------------------------------------------------===//
841cb504bSMatthias Springer 
941cb504bSMatthias Springer #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
1041cb504bSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1141cb504bSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1241cb504bSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
1341cb504bSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1441cb504bSMatthias Springer #include "mlir/IR/Dialect.h"
1541cb504bSMatthias Springer #include "mlir/IR/Operation.h"
1641cb504bSMatthias Springer 
1741cb504bSMatthias Springer using namespace mlir;
1841cb504bSMatthias Springer using namespace linalg;
1941cb504bSMatthias Springer using namespace mlir::bufferization;
2041cb504bSMatthias Springer 
2141cb504bSMatthias Springer namespace {
2241cb504bSMatthias Springer 
2341cb504bSMatthias Springer /// Generic conversion for any LinalgOp on tensors.
bufferizeLinalgOp(RewriterBase & rewriter,LinalgOp op,const BufferizationOptions & options)2441cb504bSMatthias Springer static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
25b55d55ecSMatthias Springer                                        const BufferizationOptions &options) {
2641cb504bSMatthias Springer   // Take a guard before anything else.
2741cb504bSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
2841cb504bSMatthias Springer   rewriter.setInsertionPoint(op);
2941cb504bSMatthias Springer 
3041cb504bSMatthias Springer   // Nothing to do. This op is already bufferized.
3141cb504bSMatthias Springer   if (op.hasBufferSemantics())
3241cb504bSMatthias Springer     return success();
3341cb504bSMatthias Springer 
3441cb504bSMatthias Springer   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
3541cb504bSMatthias Springer   // basis.
3641cb504bSMatthias Springer   if (!op.hasTensorSemantics())
3741cb504bSMatthias Springer     return op->emitError() << "op does not have tensor semantics";
3841cb504bSMatthias Springer 
3941cb504bSMatthias Springer   // New input operands for the cloned op.
4041cb504bSMatthias Springer   SmallVector<Value> newInputBuffers;
4141cb504bSMatthias Springer   newInputBuffers.reserve(op.getNumInputs());
4241cb504bSMatthias Springer   for (OpOperand *opOperand : op.getInputOperands()) {
4341cb504bSMatthias Springer     if (op.isScalar(opOperand)) {
4441cb504bSMatthias Springer       newInputBuffers.push_back(opOperand->get());
4541cb504bSMatthias Springer       continue;
4641cb504bSMatthias Springer     }
47*5d50f51cSMatthias Springer     FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
48*5d50f51cSMatthias Springer     if (failed(buffer))
49*5d50f51cSMatthias Springer       return failure();
50*5d50f51cSMatthias Springer     newInputBuffers.push_back(*buffer);
5141cb504bSMatthias Springer   }
5241cb504bSMatthias Springer 
5341cb504bSMatthias Springer   // New output operands for the cloned op.
5441cb504bSMatthias Springer   SmallVector<Value> newOutputBuffers;
5541cb504bSMatthias Springer   for (OpResult opResult : op->getOpResults()) {
56b3ebe3beSMatthias Springer     OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
57*5d50f51cSMatthias Springer     FailureOr<Value> resultBuffer =
58*5d50f51cSMatthias Springer         getBuffer(rewriter, opOperand->get(), options);
59*5d50f51cSMatthias Springer     if (failed(resultBuffer))
60*5d50f51cSMatthias Springer       return failure();
61*5d50f51cSMatthias Springer     newOutputBuffers.push_back(*resultBuffer);
6241cb504bSMatthias Springer   }
6341cb504bSMatthias Springer 
6441cb504bSMatthias Springer   // Merge input/output operands.
6541cb504bSMatthias Springer   SmallVector<Value> newOperands = newInputBuffers;
6641cb504bSMatthias Springer   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
6741cb504bSMatthias Springer 
6841cb504bSMatthias Springer   // Set insertion point now that potential alloc/dealloc are introduced.
6941cb504bSMatthias Springer   rewriter.setInsertionPoint(op);
7041cb504bSMatthias Springer   // Clone the op, but use the new operands. Move the existing block into the
7141cb504bSMatthias Springer   // new op. Since the new op does not have any tensor results, it does not
7241cb504bSMatthias Springer   // return anything.
7341cb504bSMatthias Springer   assert(op->getNumRegions() == 1 && "expected that op has 1 region");
7441cb504bSMatthias Springer   auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
7541cb504bSMatthias Springer       rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
7641cb504bSMatthias Springer   rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
7741cb504bSMatthias Springer                               newOp->getRegion(0).begin());
7841cb504bSMatthias Springer 
7941cb504bSMatthias Springer   // Replace the results of the old op with the new output buffers.
8041cb504bSMatthias Springer   replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
8141cb504bSMatthias Springer 
8241cb504bSMatthias Springer   return success();
8341cb504bSMatthias Springer }
8441cb504bSMatthias Springer 
8541cb504bSMatthias Springer /// Bufferization of linalg.generic. Replace with a new linalg.generic that
8641cb504bSMatthias Springer /// operates entirely on memrefs.
8741cb504bSMatthias Springer template <typename OpTy>
8841cb504bSMatthias Springer struct LinalgOpInterface
8941cb504bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
9041cb504bSMatthias Springer                                                     OpTy> {
bufferizesToMemoryRead__anon0de880b50111::LinalgOpInterface9141cb504bSMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
929597b16aSMatthias Springer                               const AnalysisState &state) const {
9341cb504bSMatthias Springer     // Operand is read if it is used in the computation.
9441cb504bSMatthias Springer     auto genericOp = cast<linalg::LinalgOp>(op);
9541cb504bSMatthias Springer     return genericOp.payloadUsesValueFromOperand(&opOperand);
9641cb504bSMatthias Springer   }
9741cb504bSMatthias Springer 
bufferizesToMemoryWrite__anon0de880b50111::LinalgOpInterface9841cb504bSMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
999597b16aSMatthias Springer                                const AnalysisState &state) const {
10025bc6846SMatthias Springer     // Operand is written to if it has an aliasing OpResult.
10141cb504bSMatthias Springer     auto bufferizableOp = cast<BufferizableOpInterface>(op);
10241cb504bSMatthias Springer     return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
10341cb504bSMatthias Springer   }
10441cb504bSMatthias Springer 
10541cb504bSMatthias Springer   SmallVector<OpOperand *>
getAliasingOpOperand__anon0de880b50111::LinalgOpInterface10641cb504bSMatthias Springer   getAliasingOpOperand(Operation *op, OpResult opResult,
1079597b16aSMatthias Springer                        const AnalysisState &state) const {
10841cb504bSMatthias Springer     auto genericOp = cast<linalg::LinalgOp>(op);
10941cb504bSMatthias Springer 
110ad2e635fSMatthias Springer     // The i-th OpResult may alias with the i-th "out" tensor.
11125bc6846SMatthias Springer     return {genericOp.getOutputOperand(opResult.getResultNumber())};
11241cb504bSMatthias Springer   }
11341cb504bSMatthias Springer 
getAliasingOpResult__anon0de880b50111::LinalgOpInterface1149597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1159597b16aSMatthias Springer                                             const AnalysisState &state) const {
11641cb504bSMatthias Springer     auto genericOp = cast<linalg::LinalgOp>(op);
11741cb504bSMatthias Springer 
118ad2e635fSMatthias Springer     // The i-th "out" tensor may alias with the i-th OpResult.
11925bc6846SMatthias Springer     if (genericOp.isOutputTensor(&opOperand))
12025bc6846SMatthias Springer       return {genericOp.getTiedOpResult(&opOperand)};
12125bc6846SMatthias Springer     return {};
12225bc6846SMatthias Springer   }
12325bc6846SMatthias Springer 
bufferRelation__anon0de880b50111::LinalgOpInterface12441cb504bSMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1259597b16aSMatthias Springer                                 const AnalysisState &state) const {
12641cb504bSMatthias Springer     return BufferRelation::Equivalent;
12741cb504bSMatthias Springer   }
12841cb504bSMatthias Springer 
bufferize__anon0de880b50111::LinalgOpInterface12941cb504bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
131b55d55ecSMatthias Springer     return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
13241cb504bSMatthias Springer   }
13341cb504bSMatthias Springer };
13441cb504bSMatthias Springer 
13541cb504bSMatthias Springer /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
13641cb504bSMatthias Springer /// the `BufferizableOpInterface` with each of them.
13777eee579SRiver Riddle template <typename... Ops>
13877eee579SRiver Riddle struct LinalgOpInterfaceHelper {
registerOpInterface__anon0de880b50111::LinalgOpInterfaceHelper13977eee579SRiver Riddle   static void registerOpInterface(MLIRContext *ctx) {
14077eee579SRiver Riddle     (void)std::initializer_list<int>{
14177eee579SRiver Riddle         0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
14241cb504bSMatthias Springer   }
14341cb504bSMatthias Springer };
14441cb504bSMatthias Springer } // namespace
14541cb504bSMatthias Springer 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)14641cb504bSMatthias Springer void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
14741cb504bSMatthias Springer     DialectRegistry &registry) {
14877eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
14941cb504bSMatthias Springer     // Register all Linalg structured ops. `LinalgOp` is an interface and it is
15041cb504bSMatthias Springer     // not possible to attach an external interface to an existing interface.
15141cb504bSMatthias Springer     // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
15241cb504bSMatthias Springer     LinalgOpInterfaceHelper<
15341cb504bSMatthias Springer #define GET_OP_LIST
15441cb504bSMatthias Springer #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
15577eee579SRiver Riddle         >::registerOpInterface(ctx);
15677eee579SRiver Riddle   });
15741cb504bSMatthias Springer }
158