1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Operation.h" 15 16 using namespace mlir; 17 using namespace mlir::bufferization; 18 using namespace mlir::vector; 19 20 namespace mlir { 21 namespace vector { 22 namespace { 23 24 /// Bufferization of vector.transfer_read. Replaced with a new 25 /// vector.transfer_read that operates on a memref. 26 struct TransferReadOpInterface 27 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 28 vector::TransferReadOp> { 29 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 30 const AnalysisState &state) const { 31 assert(opOperand.get().getType().isa<RankedTensorType>() && 32 "only tensor types expected"); 33 return true; 34 } 35 36 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 37 const AnalysisState &state) const { 38 assert(opOperand.get().getType().isa<RankedTensorType>() && 39 "only tensor types expected"); 40 return false; 41 } 42 43 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 44 const AnalysisState &state) const { 45 return {}; 46 } 47 48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 49 BufferizationState &state) const { 50 auto readOp = cast<vector::TransferReadOp>(op); 51 assert(readOp.getShapedType().isa<TensorType>() && 52 "only tensor types expected"); 53 Value buffer = state.getBuffer(rewriter, readOp.getSource()); 54 replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 55 rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), 56 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), 57 readOp.getInBoundsAttr()); 58 return success(); 59 } 60 }; 61 62 /// Bufferization of vector.transfer_write. Replace with a new 63 /// vector.transfer_write that operates on a memref. 64 struct TransferWriteOpInterface 65 : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 66 vector::TransferWriteOp> { 67 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 68 const AnalysisState &state) const { 69 assert(opOperand.get().getType().isa<TensorType>() && 70 "only tensor types expected"); 71 return true; 72 } 73 74 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 75 const AnalysisState &state) const { 76 assert(opOperand.get().getType().isa<TensorType>() && 77 "only tensor types expected"); 78 return true; 79 } 80 81 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 82 const AnalysisState &state) const { 83 assert(opOperand.get().getType().isa<TensorType>() && 84 "only tensor types expected"); 85 return {op->getOpResult(0)}; 86 } 87 88 BufferRelation bufferRelation(Operation *op, OpResult opResult, 89 const AnalysisState &state) const { 90 return BufferRelation::Equivalent; 91 } 92 93 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 94 BufferizationState &state) const { 95 auto writeOp = cast<vector::TransferWriteOp>(op); 96 assert(writeOp.getShapedType().isa<TensorType>() && 97 "only tensor types expected"); 98 99 // Create a new transfer_write on buffer that doesn't have a return value. 100 Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource()); 101 rewriter.create<vector::TransferWriteOp>( 102 writeOp.getLoc(), writeOp.getVector(), resultBuffer, 103 writeOp.getIndices(), writeOp.getPermutationMapAttr(), 104 writeOp.getInBoundsAttr()); 105 replaceOpWithBufferizedValues(rewriter, op, resultBuffer); 106 107 return success(); 108 } 109 }; 110 111 } // namespace 112 } // namespace vector 113 } // namespace mlir 114 115 void mlir::vector::registerBufferizableOpInterfaceExternalModels( 116 DialectRegistry ®istry) { 117 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { 118 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx); 119 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx); 120 }); 121 } 122