15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
25523c145SMatthias Springer //
35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65523c145SMatthias Springer //
75523c145SMatthias Springer //===----------------------------------------------------------------------===//
85523c145SMatthias Springer 
95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
105523c145SMatthias Springer 
115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
135523c145SMatthias Springer #include "mlir/IR/Dialect.h"
145523c145SMatthias Springer #include "mlir/IR/Operation.h"
155523c145SMatthias Springer 
165523c145SMatthias Springer using namespace mlir;
175523c145SMatthias Springer using namespace mlir::bufferization;
185523c145SMatthias Springer using namespace mlir::vector;
195523c145SMatthias Springer 
205523c145SMatthias Springer namespace mlir {
215523c145SMatthias Springer namespace vector {
225523c145SMatthias Springer namespace {
235523c145SMatthias Springer 
245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new
255523c145SMatthias Springer /// vector.transfer_read that operates on a memref.
265523c145SMatthias Springer struct TransferReadOpInterface
275523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
285523c145SMatthias Springer                                                     vector::TransferReadOp> {
295523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
305523c145SMatthias Springer                               const BufferizationState &state) const {
315523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
325523c145SMatthias Springer            "only tensor types expected");
335523c145SMatthias Springer     return true;
345523c145SMatthias Springer   }
355523c145SMatthias Springer 
365523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
375523c145SMatthias Springer                                const BufferizationState &state) const {
385523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
395523c145SMatthias Springer            "only tensor types expected");
405523c145SMatthias Springer     return false;
415523c145SMatthias Springer   }
425523c145SMatthias Springer 
43*585a8a32SMatthias Springer   SmallVector<OpResult>
44*585a8a32SMatthias Springer   getAliasingOpResult(Operation *op, OpOperand &opOperand,
455523c145SMatthias Springer                       const BufferizationState &state) const {
46*585a8a32SMatthias Springer     return {};
475523c145SMatthias Springer   }
485523c145SMatthias Springer 
495523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
505523c145SMatthias Springer                           const BufferizationState &state) const {
515523c145SMatthias Springer     auto readOp = cast<vector::TransferReadOp>(op);
525523c145SMatthias Springer     assert(readOp.getShapedType().isa<TensorType>() &&
535523c145SMatthias Springer            "only tensor types expected");
545523c145SMatthias Springer 
555523c145SMatthias Springer     // TransferReadOp always reads from the bufferized op.source().
565523c145SMatthias Springer     Value buffer =
575523c145SMatthias Springer         *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
585523c145SMatthias Springer     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
595523c145SMatthias Springer         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
605523c145SMatthias Springer         readOp.permutation_map(), readOp.padding(), readOp.mask(),
615523c145SMatthias Springer         readOp.in_boundsAttr());
625523c145SMatthias Springer     return success();
635523c145SMatthias Springer   }
645523c145SMatthias Springer };
655523c145SMatthias Springer 
665523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new
675523c145SMatthias Springer /// vector.transfer_write that operates on a memref.
685523c145SMatthias Springer struct TransferWriteOpInterface
695523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
705523c145SMatthias Springer                                                     vector::TransferWriteOp> {
715523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
725523c145SMatthias Springer                               const BufferizationState &state) const {
735523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
745523c145SMatthias Springer            "only tensor types expected");
755523c145SMatthias Springer     return true;
765523c145SMatthias Springer   }
775523c145SMatthias Springer 
785523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
795523c145SMatthias Springer                                const BufferizationState &state) const {
805523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
815523c145SMatthias Springer            "only tensor types expected");
825523c145SMatthias Springer     return true;
835523c145SMatthias Springer   }
845523c145SMatthias Springer 
85*585a8a32SMatthias Springer   SmallVector<OpResult>
86*585a8a32SMatthias Springer   getAliasingOpResult(Operation *op, OpOperand &opOperand,
875523c145SMatthias Springer                       const BufferizationState &state) const {
885523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
895523c145SMatthias Springer            "only tensor types expected");
90*585a8a32SMatthias Springer     return {op->getOpResult(0)};
915523c145SMatthias Springer   }
925523c145SMatthias Springer 
935523c145SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
945523c145SMatthias Springer                                 const BufferizationState &state) const {
955523c145SMatthias Springer     return BufferRelation::Equivalent;
965523c145SMatthias Springer   }
975523c145SMatthias Springer 
985523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
995523c145SMatthias Springer                           const BufferizationState &state) const {
1005523c145SMatthias Springer     auto writeOp = cast<vector::TransferWriteOp>(op);
1015523c145SMatthias Springer     assert(writeOp.getShapedType().isa<TensorType>() &&
1025523c145SMatthias Springer            "only tensor types expected");
1035523c145SMatthias Springer 
1045523c145SMatthias Springer     // Create a new transfer_write on buffer that doesn't have a return value.
1055523c145SMatthias Springer     // Leave the previous transfer_write to dead code as it still has uses at
1065523c145SMatthias Springer     // this point.
1075523c145SMatthias Springer     FailureOr<Value> resultBuffer =
1085523c145SMatthias Springer         state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
1095523c145SMatthias Springer     if (failed(resultBuffer))
1105523c145SMatthias Springer       return failure();
1115523c145SMatthias Springer     rewriter.create<vector::TransferWriteOp>(
1125523c145SMatthias Springer         writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
1135523c145SMatthias Springer         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
1145523c145SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1155523c145SMatthias Springer 
1165523c145SMatthias Springer     return success();
1175523c145SMatthias Springer   }
1185523c145SMatthias Springer };
1195523c145SMatthias Springer 
1205523c145SMatthias Springer } // namespace
1215523c145SMatthias Springer } // namespace vector
1225523c145SMatthias Springer } // namespace mlir
1235523c145SMatthias Springer 
1245523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels(
1255523c145SMatthias Springer     DialectRegistry &registry) {
1265523c145SMatthias Springer   registry.addOpInterface<TransferReadOp, TransferReadOpInterface>();
1275523c145SMatthias Springer   registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>();
1285523c145SMatthias Springer }
129