1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Operation.h" 15 16 using namespace mlir; 17 using namespace mlir::bufferization; 18 using namespace mlir::vector; 19 20 namespace mlir { 21 namespace vector { 22 namespace { 23 24 /// Bufferization of vector.transfer_read. Replaced with a new 25 /// vector.transfer_read that operates on a memref. 26 struct TransferReadOpInterface 27 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 28 vector::TransferReadOp> { 29 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 30 const AnalysisState &state) const { 31 assert(opOperand.get().getType().isa<RankedTensorType>() && 32 "only tensor types expected"); 33 return true; 34 } 35 36 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 37 const AnalysisState &state) const { 38 assert(opOperand.get().getType().isa<RankedTensorType>() && 39 "only tensor types expected"); 40 return false; 41 } 42 43 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 44 const AnalysisState &state) const { 45 return {}; 46 } 47 48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 49 BufferizationState &state) const { 50 auto readOp = cast<vector::TransferReadOp>(op); 51 assert(readOp.getShapedType().isa<TensorType>() && 52 "only tensor types expected"); 53 54 // TransferReadOp always reads from the bufferized op.source(). 55 Value buffer = 56 *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); 57 replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 58 rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), 59 readOp.permutation_map(), readOp.padding(), readOp.mask(), 60 readOp.in_boundsAttr()); 61 return success(); 62 } 63 }; 64 65 /// Bufferization of vector.transfer_write. Replace with a new 66 /// vector.transfer_write that operates on a memref. 67 struct TransferWriteOpInterface 68 : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 69 vector::TransferWriteOp> { 70 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 71 const AnalysisState &state) const { 72 assert(opOperand.get().getType().isa<TensorType>() && 73 "only tensor types expected"); 74 return true; 75 } 76 77 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 78 const AnalysisState &state) const { 79 assert(opOperand.get().getType().isa<TensorType>() && 80 "only tensor types expected"); 81 return true; 82 } 83 84 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 85 const AnalysisState &state) const { 86 assert(opOperand.get().getType().isa<TensorType>() && 87 "only tensor types expected"); 88 return {op->getOpResult(0)}; 89 } 90 91 BufferRelation bufferRelation(Operation *op, OpResult opResult, 92 const AnalysisState &state) const { 93 return BufferRelation::Equivalent; 94 } 95 96 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 97 BufferizationState &state) const { 98 auto writeOp = cast<vector::TransferWriteOp>(op); 99 assert(writeOp.getShapedType().isa<TensorType>() && 100 "only tensor types expected"); 101 102 // Create a new transfer_write on buffer that doesn't have a return value. 103 // Leave the previous transfer_write to dead code as it still has uses at 104 // this point. 105 FailureOr<Value> resultBuffer = 106 state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); 107 if (failed(resultBuffer)) 108 return failure(); 109 rewriter.create<vector::TransferWriteOp>( 110 writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), 111 writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); 112 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 113 114 return success(); 115 } 116 }; 117 118 } // namespace 119 } // namespace vector 120 } // namespace mlir 121 122 void mlir::vector::registerBufferizableOpInterfaceExternalModels( 123 DialectRegistry ®istry) { 124 registry.addOpInterface<TransferReadOp, TransferReadOpInterface>(); 125 registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>(); 126 } 127