1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Operation.h" 15 16 using namespace mlir; 17 using namespace mlir::bufferization; 18 using namespace mlir::vector; 19 20 namespace mlir { 21 namespace vector { 22 namespace { 23 24 /// Bufferization of vector.transfer_read. Replaced with a new 25 /// vector.transfer_read that operates on a memref. 26 struct TransferReadOpInterface 27 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 28 vector::TransferReadOp> { 29 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 30 const BufferizationState &state) const { 31 assert(opOperand.get().getType().isa<RankedTensorType>() && 32 "only tensor types expected"); 33 return true; 34 } 35 36 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 37 const BufferizationState &state) const { 38 assert(opOperand.get().getType().isa<RankedTensorType>() && 39 "only tensor types expected"); 40 return false; 41 } 42 43 SmallVector<OpResult> 44 getAliasingOpResult(Operation *op, OpOperand &opOperand, 45 const BufferizationState &state) const { 46 return {}; 47 } 48 49 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 50 const BufferizationState &state) const { 51 auto readOp = cast<vector::TransferReadOp>(op); 52 assert(readOp.getShapedType().isa<TensorType>() && 53 "only tensor types expected"); 54 55 // TransferReadOp always reads from the bufferized op.source(). 56 Value buffer = 57 *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); 58 replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 59 rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), 60 readOp.permutation_map(), readOp.padding(), readOp.mask(), 61 readOp.in_boundsAttr()); 62 return success(); 63 } 64 }; 65 66 /// Bufferization of vector.transfer_write. Replace with a new 67 /// vector.transfer_write that operates on a memref. 68 struct TransferWriteOpInterface 69 : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 70 vector::TransferWriteOp> { 71 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 72 const BufferizationState &state) const { 73 assert(opOperand.get().getType().isa<TensorType>() && 74 "only tensor types expected"); 75 return true; 76 } 77 78 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 79 const BufferizationState &state) const { 80 assert(opOperand.get().getType().isa<TensorType>() && 81 "only tensor types expected"); 82 return true; 83 } 84 85 SmallVector<OpResult> 86 getAliasingOpResult(Operation *op, OpOperand &opOperand, 87 const BufferizationState &state) const { 88 assert(opOperand.get().getType().isa<TensorType>() && 89 "only tensor types expected"); 90 return {op->getOpResult(0)}; 91 } 92 93 BufferRelation bufferRelation(Operation *op, OpResult opResult, 94 const BufferizationState &state) const { 95 return BufferRelation::Equivalent; 96 } 97 98 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 99 const BufferizationState &state) const { 100 auto writeOp = cast<vector::TransferWriteOp>(op); 101 assert(writeOp.getShapedType().isa<TensorType>() && 102 "only tensor types expected"); 103 104 // Create a new transfer_write on buffer that doesn't have a return value. 105 // Leave the previous transfer_write to dead code as it still has uses at 106 // this point. 107 FailureOr<Value> resultBuffer = 108 state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); 109 if (failed(resultBuffer)) 110 return failure(); 111 rewriter.create<vector::TransferWriteOp>( 112 writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), 113 writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); 114 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 115 116 return success(); 117 } 118 }; 119 120 } // namespace 121 } // namespace vector 122 } // namespace mlir 123 124 void mlir::vector::registerBufferizableOpInterfaceExternalModels( 125 DialectRegistry ®istry) { 126 registry.addOpInterface<TransferReadOp, TransferReadOpInterface>(); 127 registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>(); 128 } 129