15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
25523c145SMatthias Springer //
35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65523c145SMatthias Springer //
75523c145SMatthias Springer //===----------------------------------------------------------------------===//
85523c145SMatthias Springer 
95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
105523c145SMatthias Springer 
115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
135523c145SMatthias Springer #include "mlir/IR/Dialect.h"
145523c145SMatthias Springer #include "mlir/IR/Operation.h"
155523c145SMatthias Springer 
165523c145SMatthias Springer using namespace mlir;
175523c145SMatthias Springer using namespace mlir::bufferization;
185523c145SMatthias Springer using namespace mlir::vector;
195523c145SMatthias Springer 
205523c145SMatthias Springer namespace mlir {
215523c145SMatthias Springer namespace vector {
225523c145SMatthias Springer namespace {
235523c145SMatthias Springer 
245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new
255523c145SMatthias Springer /// vector.transfer_read that operates on a memref.
265523c145SMatthias Springer struct TransferReadOpInterface
275523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
285523c145SMatthias Springer                                                     vector::TransferReadOp> {
295523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
309597b16aSMatthias Springer                               const AnalysisState &state) const {
315523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
325523c145SMatthias Springer            "only tensor types expected");
335523c145SMatthias Springer     return true;
345523c145SMatthias Springer   }
355523c145SMatthias Springer 
365523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
379597b16aSMatthias Springer                                const AnalysisState &state) const {
385523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
395523c145SMatthias Springer            "only tensor types expected");
405523c145SMatthias Springer     return false;
415523c145SMatthias Springer   }
425523c145SMatthias Springer 
439597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
449597b16aSMatthias Springer                                             const AnalysisState &state) const {
45585a8a32SMatthias Springer     return {};
465523c145SMatthias Springer   }
475523c145SMatthias Springer 
485523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
499597b16aSMatthias Springer                           BufferizationState &state) const {
505523c145SMatthias Springer     auto readOp = cast<vector::TransferReadOp>(op);
515523c145SMatthias Springer     assert(readOp.getShapedType().isa<TensorType>() &&
525523c145SMatthias Springer            "only tensor types expected");
535523c145SMatthias Springer 
545523c145SMatthias Springer     // TransferReadOp always reads from the bufferized op.source().
555523c145SMatthias Springer     Value buffer =
565523c145SMatthias Springer         *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
575523c145SMatthias Springer     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
585523c145SMatthias Springer         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
595523c145SMatthias Springer         readOp.permutation_map(), readOp.padding(), readOp.mask(),
605523c145SMatthias Springer         readOp.in_boundsAttr());
615523c145SMatthias Springer     return success();
625523c145SMatthias Springer   }
635523c145SMatthias Springer };
645523c145SMatthias Springer 
655523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new
665523c145SMatthias Springer /// vector.transfer_write that operates on a memref.
675523c145SMatthias Springer struct TransferWriteOpInterface
685523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
695523c145SMatthias Springer                                                     vector::TransferWriteOp> {
705523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
719597b16aSMatthias Springer                               const AnalysisState &state) const {
725523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
735523c145SMatthias Springer            "only tensor types expected");
745523c145SMatthias Springer     return true;
755523c145SMatthias Springer   }
765523c145SMatthias Springer 
775523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
789597b16aSMatthias Springer                                const AnalysisState &state) const {
795523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
805523c145SMatthias Springer            "only tensor types expected");
815523c145SMatthias Springer     return true;
825523c145SMatthias Springer   }
835523c145SMatthias Springer 
849597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
859597b16aSMatthias Springer                                             const AnalysisState &state) const {
865523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
875523c145SMatthias Springer            "only tensor types expected");
88585a8a32SMatthias Springer     return {op->getOpResult(0)};
895523c145SMatthias Springer   }
905523c145SMatthias Springer 
915523c145SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
929597b16aSMatthias Springer                                 const AnalysisState &state) const {
935523c145SMatthias Springer     return BufferRelation::Equivalent;
945523c145SMatthias Springer   }
955523c145SMatthias Springer 
965523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
979597b16aSMatthias Springer                           BufferizationState &state) const {
985523c145SMatthias Springer     auto writeOp = cast<vector::TransferWriteOp>(op);
995523c145SMatthias Springer     assert(writeOp.getShapedType().isa<TensorType>() &&
1005523c145SMatthias Springer            "only tensor types expected");
1015523c145SMatthias Springer 
1025523c145SMatthias Springer     // Create a new transfer_write on buffer that doesn't have a return value.
1035523c145SMatthias Springer     // Leave the previous transfer_write to dead code as it still has uses at
1045523c145SMatthias Springer     // this point.
1055523c145SMatthias Springer     FailureOr<Value> resultBuffer =
1065523c145SMatthias Springer         state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
1075523c145SMatthias Springer     if (failed(resultBuffer))
1085523c145SMatthias Springer       return failure();
1095523c145SMatthias Springer     rewriter.create<vector::TransferWriteOp>(
1105523c145SMatthias Springer         writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
1115523c145SMatthias Springer         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
1125523c145SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1135523c145SMatthias Springer 
1145523c145SMatthias Springer     return success();
1155523c145SMatthias Springer   }
1165523c145SMatthias Springer };
1175523c145SMatthias Springer 
1185523c145SMatthias Springer } // namespace
1195523c145SMatthias Springer } // namespace vector
1205523c145SMatthias Springer } // namespace mlir
1215523c145SMatthias Springer 
1225523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels(
1235523c145SMatthias Springer     DialectRegistry &registry) {
124*77eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
125*77eee579SRiver Riddle     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
126*77eee579SRiver Riddle     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
127*77eee579SRiver Riddle   });
1285523c145SMatthias Springer }
129