15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
25523c145SMatthias Springer //
35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65523c145SMatthias Springer //
75523c145SMatthias Springer //===----------------------------------------------------------------------===//
85523c145SMatthias Springer 
95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
105523c145SMatthias Springer 
115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
135523c145SMatthias Springer #include "mlir/IR/Dialect.h"
145523c145SMatthias Springer #include "mlir/IR/Operation.h"
155523c145SMatthias Springer 
165523c145SMatthias Springer using namespace mlir;
175523c145SMatthias Springer using namespace mlir::bufferization;
185523c145SMatthias Springer using namespace mlir::vector;
195523c145SMatthias Springer 
205523c145SMatthias Springer namespace mlir {
215523c145SMatthias Springer namespace vector {
225523c145SMatthias Springer namespace {
235523c145SMatthias Springer 
245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new
255523c145SMatthias Springer /// vector.transfer_read that operates on a memref.
265523c145SMatthias Springer struct TransferReadOpInterface
275523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
285523c145SMatthias Springer                                                     vector::TransferReadOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferReadOpInterface295523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
309597b16aSMatthias Springer                               const AnalysisState &state) const {
315523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
325523c145SMatthias Springer            "only tensor types expected");
335523c145SMatthias Springer     return true;
345523c145SMatthias Springer   }
355523c145SMatthias Springer 
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferReadOpInterface365523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
379597b16aSMatthias Springer                                const AnalysisState &state) const {
385523c145SMatthias Springer     assert(opOperand.get().getType().isa<RankedTensorType>() &&
395523c145SMatthias Springer            "only tensor types expected");
405523c145SMatthias Springer     return false;
415523c145SMatthias Springer   }
425523c145SMatthias Springer 
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferReadOpInterface439597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
449597b16aSMatthias Springer                                             const AnalysisState &state) const {
45585a8a32SMatthias Springer     return {};
465523c145SMatthias Springer   }
475523c145SMatthias Springer 
bufferizemlir::vector::__anon48b3dc910111::TransferReadOpInterface485523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
505523c145SMatthias Springer     auto readOp = cast<vector::TransferReadOp>(op);
515523c145SMatthias Springer     assert(readOp.getShapedType().isa<TensorType>() &&
525523c145SMatthias Springer            "only tensor types expected");
535d50f51cSMatthias Springer     FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
545d50f51cSMatthias Springer     if (failed(buffer))
555d50f51cSMatthias Springer       return failure();
565523c145SMatthias Springer     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
575d50f51cSMatthias Springer         rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
587c38fd60SJacques Pienaar         readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
597c38fd60SJacques Pienaar         readOp.getInBoundsAttr());
605523c145SMatthias Springer     return success();
615523c145SMatthias Springer   }
625523c145SMatthias Springer };
635523c145SMatthias Springer 
645523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new
655523c145SMatthias Springer /// vector.transfer_write that operates on a memref.
665523c145SMatthias Springer struct TransferWriteOpInterface
675523c145SMatthias Springer     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
685523c145SMatthias Springer                                                     vector::TransferWriteOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferWriteOpInterface695523c145SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
709597b16aSMatthias Springer                               const AnalysisState &state) const {
715523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
725523c145SMatthias Springer            "only tensor types expected");
735523c145SMatthias Springer     return true;
745523c145SMatthias Springer   }
755523c145SMatthias Springer 
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferWriteOpInterface765523c145SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
779597b16aSMatthias Springer                                const AnalysisState &state) const {
785523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
795523c145SMatthias Springer            "only tensor types expected");
805523c145SMatthias Springer     return true;
815523c145SMatthias Springer   }
825523c145SMatthias Springer 
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferWriteOpInterface839597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
849597b16aSMatthias Springer                                             const AnalysisState &state) const {
855523c145SMatthias Springer     assert(opOperand.get().getType().isa<TensorType>() &&
865523c145SMatthias Springer            "only tensor types expected");
87585a8a32SMatthias Springer     return {op->getOpResult(0)};
885523c145SMatthias Springer   }
895523c145SMatthias Springer 
bufferRelationmlir::vector::__anon48b3dc910111::TransferWriteOpInterface905523c145SMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
919597b16aSMatthias Springer                                 const AnalysisState &state) const {
925523c145SMatthias Springer     return BufferRelation::Equivalent;
935523c145SMatthias Springer   }
945523c145SMatthias Springer 
bufferizemlir::vector::__anon48b3dc910111::TransferWriteOpInterface955523c145SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
96b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
975523c145SMatthias Springer     auto writeOp = cast<vector::TransferWriteOp>(op);
985523c145SMatthias Springer     assert(writeOp.getShapedType().isa<TensorType>() &&
995523c145SMatthias Springer            "only tensor types expected");
1005523c145SMatthias Springer 
1015523c145SMatthias Springer     // Create a new transfer_write on buffer that doesn't have a return value.
1025d50f51cSMatthias Springer     FailureOr<Value> resultBuffer =
1035d50f51cSMatthias Springer         getBuffer(rewriter, writeOp.getSource(), options);
1045d50f51cSMatthias Springer     if (failed(resultBuffer))
1055d50f51cSMatthias Springer       return failure();
1065523c145SMatthias Springer     rewriter.create<vector::TransferWriteOp>(
1075d50f51cSMatthias Springer         writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
1087c38fd60SJacques Pienaar         writeOp.getIndices(), writeOp.getPermutationMapAttr(),
109*a28ce1a4SMatthias Springer         writeOp.getMask(), writeOp.getInBoundsAttr());
1105d50f51cSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1115523c145SMatthias Springer 
1125523c145SMatthias Springer     return success();
1135523c145SMatthias Springer   }
1145523c145SMatthias Springer };
1155523c145SMatthias Springer 
1165523c145SMatthias Springer } // namespace
1175523c145SMatthias Springer } // namespace vector
1185523c145SMatthias Springer } // namespace mlir
1195523c145SMatthias Springer 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)1205523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels(
1215523c145SMatthias Springer     DialectRegistry &registry) {
12277eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
12377eee579SRiver Riddle     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
12477eee579SRiver Riddle     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
12577eee579SRiver Riddle   });
1265523c145SMatthias Springer }
127