15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
25523c145SMatthias Springer //
35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65523c145SMatthias Springer //
75523c145SMatthias Springer //===----------------------------------------------------------------------===//
85523c145SMatthias Springer
95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
105523c145SMatthias Springer
115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
135523c145SMatthias Springer #include "mlir/IR/Dialect.h"
145523c145SMatthias Springer #include "mlir/IR/Operation.h"
155523c145SMatthias Springer
165523c145SMatthias Springer using namespace mlir;
175523c145SMatthias Springer using namespace mlir::bufferization;
185523c145SMatthias Springer using namespace mlir::vector;
195523c145SMatthias Springer
205523c145SMatthias Springer namespace mlir {
215523c145SMatthias Springer namespace vector {
225523c145SMatthias Springer namespace {
235523c145SMatthias Springer
245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new
255523c145SMatthias Springer /// vector.transfer_read that operates on a memref.
265523c145SMatthias Springer struct TransferReadOpInterface
275523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
285523c145SMatthias Springer vector::TransferReadOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferReadOpInterface295523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
309597b16aSMatthias Springer const AnalysisState &state) const {
315523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() &&
325523c145SMatthias Springer "only tensor types expected");
335523c145SMatthias Springer return true;
345523c145SMatthias Springer }
355523c145SMatthias Springer
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferReadOpInterface365523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
379597b16aSMatthias Springer const AnalysisState &state) const {
385523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() &&
395523c145SMatthias Springer "only tensor types expected");
405523c145SMatthias Springer return false;
415523c145SMatthias Springer }
425523c145SMatthias Springer
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferReadOpInterface439597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
449597b16aSMatthias Springer const AnalysisState &state) const {
45585a8a32SMatthias Springer return {};
465523c145SMatthias Springer }
475523c145SMatthias Springer
bufferizemlir::vector::__anon48b3dc910111::TransferReadOpInterface485523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49b55d55ecSMatthias Springer const BufferizationOptions &options) const {
505523c145SMatthias Springer auto readOp = cast<vector::TransferReadOp>(op);
515523c145SMatthias Springer assert(readOp.getShapedType().isa<TensorType>() &&
525523c145SMatthias Springer "only tensor types expected");
535d50f51cSMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
545d50f51cSMatthias Springer if (failed(buffer))
555d50f51cSMatthias Springer return failure();
565523c145SMatthias Springer replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
575d50f51cSMatthias Springer rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
587c38fd60SJacques Pienaar readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
597c38fd60SJacques Pienaar readOp.getInBoundsAttr());
605523c145SMatthias Springer return success();
615523c145SMatthias Springer }
625523c145SMatthias Springer };
635523c145SMatthias Springer
645523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new
655523c145SMatthias Springer /// vector.transfer_write that operates on a memref.
665523c145SMatthias Springer struct TransferWriteOpInterface
675523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
685523c145SMatthias Springer vector::TransferWriteOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferWriteOpInterface695523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
709597b16aSMatthias Springer const AnalysisState &state) const {
715523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() &&
725523c145SMatthias Springer "only tensor types expected");
735523c145SMatthias Springer return true;
745523c145SMatthias Springer }
755523c145SMatthias Springer
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferWriteOpInterface765523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
779597b16aSMatthias Springer const AnalysisState &state) const {
785523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() &&
795523c145SMatthias Springer "only tensor types expected");
805523c145SMatthias Springer return true;
815523c145SMatthias Springer }
825523c145SMatthias Springer
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferWriteOpInterface839597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
849597b16aSMatthias Springer const AnalysisState &state) const {
855523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() &&
865523c145SMatthias Springer "only tensor types expected");
87585a8a32SMatthias Springer return {op->getOpResult(0)};
885523c145SMatthias Springer }
895523c145SMatthias Springer
bufferRelationmlir::vector::__anon48b3dc910111::TransferWriteOpInterface905523c145SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
919597b16aSMatthias Springer const AnalysisState &state) const {
925523c145SMatthias Springer return BufferRelation::Equivalent;
935523c145SMatthias Springer }
945523c145SMatthias Springer
bufferizemlir::vector::__anon48b3dc910111::TransferWriteOpInterface955523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
96b55d55ecSMatthias Springer const BufferizationOptions &options) const {
975523c145SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op);
985523c145SMatthias Springer assert(writeOp.getShapedType().isa<TensorType>() &&
995523c145SMatthias Springer "only tensor types expected");
1005523c145SMatthias Springer
1015523c145SMatthias Springer // Create a new transfer_write on buffer that doesn't have a return value.
1025d50f51cSMatthias Springer FailureOr<Value> resultBuffer =
1035d50f51cSMatthias Springer getBuffer(rewriter, writeOp.getSource(), options);
1045d50f51cSMatthias Springer if (failed(resultBuffer))
1055d50f51cSMatthias Springer return failure();
1065523c145SMatthias Springer rewriter.create<vector::TransferWriteOp>(
1075d50f51cSMatthias Springer writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
1087c38fd60SJacques Pienaar writeOp.getIndices(), writeOp.getPermutationMapAttr(),
109*a28ce1a4SMatthias Springer writeOp.getMask(), writeOp.getInBoundsAttr());
1105d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1115523c145SMatthias Springer
1125523c145SMatthias Springer return success();
1135523c145SMatthias Springer }
1145523c145SMatthias Springer };
1155523c145SMatthias Springer
1165523c145SMatthias Springer } // namespace
1175523c145SMatthias Springer } // namespace vector
1185523c145SMatthias Springer } // namespace mlir
1195523c145SMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)1205523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels(
1215523c145SMatthias Springer DialectRegistry ®istry) {
12277eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
12377eee579SRiver Riddle TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
12477eee579SRiver Riddle TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
12577eee579SRiver Riddle });
1265523c145SMatthias Springer }
127