15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 25523c145SMatthias Springer // 35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65523c145SMatthias Springer // 75523c145SMatthias Springer //===----------------------------------------------------------------------===// 85523c145SMatthias Springer 95523c145SMatthias Springer #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 105523c145SMatthias Springer 115523c145SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 125523c145SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 135523c145SMatthias Springer #include "mlir/IR/Dialect.h" 145523c145SMatthias Springer #include "mlir/IR/Operation.h" 155523c145SMatthias Springer 165523c145SMatthias Springer using namespace mlir; 175523c145SMatthias Springer using namespace mlir::bufferization; 185523c145SMatthias Springer using namespace mlir::vector; 195523c145SMatthias Springer 205523c145SMatthias Springer namespace mlir { 215523c145SMatthias Springer namespace vector { 225523c145SMatthias Springer namespace { 235523c145SMatthias Springer 245523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new 255523c145SMatthias Springer /// vector.transfer_read that operates on a memref. 265523c145SMatthias Springer struct TransferReadOpInterface 275523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 285523c145SMatthias Springer vector::TransferReadOp> { 295523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 305523c145SMatthias Springer const BufferizationState &state) const { 315523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 325523c145SMatthias Springer "only tensor types expected"); 335523c145SMatthias Springer return true; 345523c145SMatthias Springer } 355523c145SMatthias Springer 365523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 375523c145SMatthias Springer const BufferizationState &state) const { 385523c145SMatthias Springer assert(opOperand.get().getType().isa<RankedTensorType>() && 395523c145SMatthias Springer "only tensor types expected"); 405523c145SMatthias Springer return false; 415523c145SMatthias Springer } 425523c145SMatthias Springer 43*585a8a32SMatthias Springer SmallVector<OpResult> 44*585a8a32SMatthias Springer getAliasingOpResult(Operation *op, OpOperand &opOperand, 455523c145SMatthias Springer const BufferizationState &state) const { 46*585a8a32SMatthias Springer return {}; 475523c145SMatthias Springer } 485523c145SMatthias Springer 495523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 505523c145SMatthias Springer const BufferizationState &state) const { 515523c145SMatthias Springer auto readOp = cast<vector::TransferReadOp>(op); 525523c145SMatthias Springer assert(readOp.getShapedType().isa<TensorType>() && 535523c145SMatthias Springer "only tensor types expected"); 545523c145SMatthias Springer 555523c145SMatthias Springer // TransferReadOp always reads from the bufferized op.source(). 565523c145SMatthias Springer Value buffer = 575523c145SMatthias Springer *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); 585523c145SMatthias Springer replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 595523c145SMatthias Springer rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), 605523c145SMatthias Springer readOp.permutation_map(), readOp.padding(), readOp.mask(), 615523c145SMatthias Springer readOp.in_boundsAttr()); 625523c145SMatthias Springer return success(); 635523c145SMatthias Springer } 645523c145SMatthias Springer }; 655523c145SMatthias Springer 665523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new 675523c145SMatthias Springer /// vector.transfer_write that operates on a memref. 685523c145SMatthias Springer struct TransferWriteOpInterface 695523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 705523c145SMatthias Springer vector::TransferWriteOp> { 715523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 725523c145SMatthias Springer const BufferizationState &state) const { 735523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 745523c145SMatthias Springer "only tensor types expected"); 755523c145SMatthias Springer return true; 765523c145SMatthias Springer } 775523c145SMatthias Springer 785523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 795523c145SMatthias Springer const BufferizationState &state) const { 805523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 815523c145SMatthias Springer "only tensor types expected"); 825523c145SMatthias Springer return true; 835523c145SMatthias Springer } 845523c145SMatthias Springer 85*585a8a32SMatthias Springer SmallVector<OpResult> 86*585a8a32SMatthias Springer getAliasingOpResult(Operation *op, OpOperand &opOperand, 875523c145SMatthias Springer const BufferizationState &state) const { 885523c145SMatthias Springer assert(opOperand.get().getType().isa<TensorType>() && 895523c145SMatthias Springer "only tensor types expected"); 90*585a8a32SMatthias Springer return {op->getOpResult(0)}; 915523c145SMatthias Springer } 925523c145SMatthias Springer 935523c145SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 945523c145SMatthias Springer const BufferizationState &state) const { 955523c145SMatthias Springer return BufferRelation::Equivalent; 965523c145SMatthias Springer } 975523c145SMatthias Springer 985523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 995523c145SMatthias Springer const BufferizationState &state) const { 1005523c145SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op); 1015523c145SMatthias Springer assert(writeOp.getShapedType().isa<TensorType>() && 1025523c145SMatthias Springer "only tensor types expected"); 1035523c145SMatthias Springer 1045523c145SMatthias Springer // Create a new transfer_write on buffer that doesn't have a return value. 1055523c145SMatthias Springer // Leave the previous transfer_write to dead code as it still has uses at 1065523c145SMatthias Springer // this point. 1075523c145SMatthias Springer FailureOr<Value> resultBuffer = 1085523c145SMatthias Springer state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); 1095523c145SMatthias Springer if (failed(resultBuffer)) 1105523c145SMatthias Springer return failure(); 1115523c145SMatthias Springer rewriter.create<vector::TransferWriteOp>( 1125523c145SMatthias Springer writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), 1135523c145SMatthias Springer writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); 1145523c145SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 1155523c145SMatthias Springer 1165523c145SMatthias Springer return success(); 1175523c145SMatthias Springer } 1185523c145SMatthias Springer }; 1195523c145SMatthias Springer 1205523c145SMatthias Springer } // namespace 1215523c145SMatthias Springer } // namespace vector 1225523c145SMatthias Springer } // namespace mlir 1235523c145SMatthias Springer 1245523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels( 1255523c145SMatthias Springer DialectRegistry ®istry) { 1265523c145SMatthias Springer registry.addOpInterface<TransferReadOp, TransferReadOpInterface>(); 1275523c145SMatthias Springer registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>(); 1285523c145SMatthias Springer } 129