15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 25523c145SMatthias Springer // 35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65523c145SMatthias Springer // 75523c145SMatthias Springer //===----------------------------------------------------------------------===// 85523c145SMatthias Springer 95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 105523c145SMatthias Springer 115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 135523c145SMatthias Springer #include "mlir/IR/Dialect.h" 145523c145SMatthias Springer #include "mlir/IR/Operation.h" 155523c145SMatthias Springer 165523c145SMatthias Springer using namespace mlir; 175523c145SMatthias Springer using namespace mlir::bufferization; 185523c145SMatthias Springer using namespace mlir::vector; 195523c145SMatthias Springer 205523c145SMatthias Springer namespace mlir { 215523c145SMatthias Springer namespace vector { 225523c145SMatthias Springer namespace { 235523c145SMatthias Springer 245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new 255523c145SMatthias Springer /// vector.transfer_read that operates on a memref. 265523c145SMatthias Springer struct TransferReadOpInterface 275523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 285523c145SMatthias Springer vector::TransferReadOp> { 295523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 309597b16aSMatthias Springer const AnalysisState &state) const { 315523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 325523c145SMatthias Springer "only tensor types expected"); 335523c145SMatthias Springer return true; 345523c145SMatthias Springer } 355523c145SMatthias Springer 365523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 379597b16aSMatthias Springer const AnalysisState &state) const { 385523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 395523c145SMatthias Springer "only tensor types expected"); 405523c145SMatthias Springer return false; 415523c145SMatthias Springer } 425523c145SMatthias Springer 439597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 449597b16aSMatthias Springer const AnalysisState &state) const { 45585a8a32SMatthias Springer return {}; 465523c145SMatthias Springer } 475523c145SMatthias Springer 485523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 499597b16aSMatthias Springer BufferizationState &state) const { 505523c145SMatthias Springer auto readOp = cast<vector::TransferReadOp>(op); 515523c145SMatthias Springer assert(readOp.getShapedType().isa<TensorType>() && 525523c145SMatthias Springer "only tensor types expected"); 535523c145SMatthias Springer 545523c145SMatthias Springer // TransferReadOp always reads from the bufferized op.source(). 555523c145SMatthias Springer Value buffer = 565523c145SMatthias Springer *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); 575523c145SMatthias Springer replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 58*7c38fd60SJacques Pienaar rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), 59*7c38fd60SJacques Pienaar readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), 60*7c38fd60SJacques Pienaar readOp.getInBoundsAttr()); 615523c145SMatthias Springer return success(); 625523c145SMatthias Springer } 635523c145SMatthias Springer }; 645523c145SMatthias Springer 655523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new 665523c145SMatthias Springer /// vector.transfer_write that operates on a memref. 675523c145SMatthias Springer struct TransferWriteOpInterface 685523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 695523c145SMatthias Springer vector::TransferWriteOp> { 705523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 719597b16aSMatthias Springer const AnalysisState &state) const { 725523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 735523c145SMatthias Springer "only tensor types expected"); 745523c145SMatthias Springer return true; 755523c145SMatthias Springer } 765523c145SMatthias Springer 775523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 789597b16aSMatthias Springer const AnalysisState &state) const { 795523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 805523c145SMatthias Springer "only tensor types expected"); 815523c145SMatthias Springer return true; 825523c145SMatthias Springer } 835523c145SMatthias Springer 849597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 859597b16aSMatthias Springer const AnalysisState &state) const { 865523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 875523c145SMatthias Springer "only tensor types expected"); 88585a8a32SMatthias Springer return {op->getOpResult(0)}; 895523c145SMatthias Springer } 905523c145SMatthias Springer 915523c145SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 929597b16aSMatthias Springer const AnalysisState &state) const { 935523c145SMatthias Springer return BufferRelation::Equivalent; 945523c145SMatthias Springer } 955523c145SMatthias Springer 965523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 979597b16aSMatthias Springer BufferizationState &state) const { 985523c145SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op); 995523c145SMatthias Springer assert(writeOp.getShapedType().isa<TensorType>() && 1005523c145SMatthias Springer "only tensor types expected"); 1015523c145SMatthias Springer 1025523c145SMatthias Springer // Create a new transfer_write on buffer that doesn't have a return value. 1035523c145SMatthias Springer // Leave the previous transfer_write to dead code as it still has uses at 1045523c145SMatthias Springer // this point. 1055523c145SMatthias Springer FailureOr<Value> resultBuffer = 1065523c145SMatthias Springer state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); 1075523c145SMatthias Springer if (failed(resultBuffer)) 1085523c145SMatthias Springer return failure(); 1095523c145SMatthias Springer rewriter.create<vector::TransferWriteOp>( 110*7c38fd60SJacques Pienaar writeOp.getLoc(), writeOp.getVector(), *resultBuffer, 111*7c38fd60SJacques Pienaar writeOp.getIndices(), writeOp.getPermutationMapAttr(), 112*7c38fd60SJacques Pienaar writeOp.getInBoundsAttr()); 1135523c145SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 1145523c145SMatthias Springer 1155523c145SMatthias Springer return success(); 1165523c145SMatthias Springer } 1175523c145SMatthias Springer }; 1185523c145SMatthias Springer 1195523c145SMatthias Springer } // namespace 1205523c145SMatthias Springer } // namespace vector 1215523c145SMatthias Springer } // namespace mlir 1225523c145SMatthias Springer 1235523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels( 1245523c145SMatthias Springer DialectRegistry ®istry) { 12577eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { 12677eee579SRiver Riddle TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx); 12777eee579SRiver Riddle TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx); 12877eee579SRiver Riddle }); 1295523c145SMatthias Springer } 130