15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 25523c145SMatthias Springer // 35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65523c145SMatthias Springer // 75523c145SMatthias Springer //===----------------------------------------------------------------------===// 85523c145SMatthias Springer 95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 105523c145SMatthias Springer 115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 135523c145SMatthias Springer #include "mlir/IR/Dialect.h" 145523c145SMatthias Springer #include "mlir/IR/Operation.h" 155523c145SMatthias Springer 165523c145SMatthias Springer using namespace mlir; 175523c145SMatthias Springer using namespace mlir::bufferization; 185523c145SMatthias Springer using namespace mlir::vector; 195523c145SMatthias Springer 205523c145SMatthias Springer namespace mlir { 215523c145SMatthias Springer namespace vector { 225523c145SMatthias Springer namespace { 235523c145SMatthias Springer 245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new 255523c145SMatthias Springer /// vector.transfer_read that operates on a memref. 265523c145SMatthias Springer struct TransferReadOpInterface 275523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 285523c145SMatthias Springer vector::TransferReadOp> { 295523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 309597b16aSMatthias Springer const AnalysisState &state) const { 315523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 325523c145SMatthias Springer "only tensor types expected"); 335523c145SMatthias Springer return true; 345523c145SMatthias Springer } 355523c145SMatthias Springer 365523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 379597b16aSMatthias Springer const AnalysisState &state) const { 385523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 395523c145SMatthias Springer "only tensor types expected"); 405523c145SMatthias Springer return false; 415523c145SMatthias Springer } 425523c145SMatthias Springer 439597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 449597b16aSMatthias Springer const AnalysisState &state) const { 45585a8a32SMatthias Springer return {}; 465523c145SMatthias Springer } 475523c145SMatthias Springer 485523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 499597b16aSMatthias Springer BufferizationState &state) const { 505523c145SMatthias Springer auto readOp = cast<vector::TransferReadOp>(op); 515523c145SMatthias Springer assert(readOp.getShapedType().isa<TensorType>() && 525523c145SMatthias Springer "only tensor types expected"); 535523c145SMatthias Springer 545523c145SMatthias Springer // TransferReadOp always reads from the bufferized op.source(). 555523c145SMatthias Springer Value buffer = 565523c145SMatthias Springer *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); 575523c145SMatthias Springer replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 585523c145SMatthias Springer rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), 595523c145SMatthias Springer readOp.permutation_map(), readOp.padding(), readOp.mask(), 605523c145SMatthias Springer readOp.in_boundsAttr()); 615523c145SMatthias Springer return success(); 625523c145SMatthias Springer } 635523c145SMatthias Springer }; 645523c145SMatthias Springer 655523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new 665523c145SMatthias Springer /// vector.transfer_write that operates on a memref. 675523c145SMatthias Springer struct TransferWriteOpInterface 685523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 695523c145SMatthias Springer vector::TransferWriteOp> { 705523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 719597b16aSMatthias Springer const AnalysisState &state) const { 725523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 735523c145SMatthias Springer "only tensor types expected"); 745523c145SMatthias Springer return true; 755523c145SMatthias Springer } 765523c145SMatthias Springer 775523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 789597b16aSMatthias Springer const AnalysisState &state) const { 795523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 805523c145SMatthias Springer "only tensor types expected"); 815523c145SMatthias Springer return true; 825523c145SMatthias Springer } 835523c145SMatthias Springer 849597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 859597b16aSMatthias Springer const AnalysisState &state) const { 865523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 875523c145SMatthias Springer "only tensor types expected"); 88585a8a32SMatthias Springer return {op->getOpResult(0)}; 895523c145SMatthias Springer } 905523c145SMatthias Springer 915523c145SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 929597b16aSMatthias Springer const AnalysisState &state) const { 935523c145SMatthias Springer return BufferRelation::Equivalent; 945523c145SMatthias Springer } 955523c145SMatthias Springer 965523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 979597b16aSMatthias Springer BufferizationState &state) const { 985523c145SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op); 995523c145SMatthias Springer assert(writeOp.getShapedType().isa<TensorType>() && 1005523c145SMatthias Springer "only tensor types expected"); 1015523c145SMatthias Springer 1025523c145SMatthias Springer // Create a new transfer_write on buffer that doesn't have a return value. 1035523c145SMatthias Springer // Leave the previous transfer_write to dead code as it still has uses at 1045523c145SMatthias Springer // this point. 1055523c145SMatthias Springer FailureOr<Value> resultBuffer = 1065523c145SMatthias Springer state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); 1075523c145SMatthias Springer if (failed(resultBuffer)) 1085523c145SMatthias Springer return failure(); 1095523c145SMatthias Springer rewriter.create<vector::TransferWriteOp>( 1105523c145SMatthias Springer writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), 1115523c145SMatthias Springer writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); 1125523c145SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 1135523c145SMatthias Springer 1145523c145SMatthias Springer return success(); 1155523c145SMatthias Springer } 1165523c145SMatthias Springer }; 1175523c145SMatthias Springer 1185523c145SMatthias Springer } // namespace 1195523c145SMatthias Springer } // namespace vector 1205523c145SMatthias Springer } // namespace mlir 1215523c145SMatthias Springer 1225523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels( 1235523c145SMatthias Springer DialectRegistry ®istry) { 124*77eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { 125*77eee579SRiver Riddle TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx); 126*77eee579SRiver Riddle TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx); 127*77eee579SRiver Riddle }); 1285523c145SMatthias Springer } 129