1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Operation.h" 15 16 using namespace mlir; 17 using namespace mlir::bufferization; 18 using namespace mlir::vector; 19 20 namespace mlir { 21 namespace vector { 22 namespace { 23 24 /// Bufferization of vector.transfer_read. Replaced with a new 25 /// vector.transfer_read that operates on a memref. 26 struct TransferReadOpInterface 27 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, 28 vector::TransferReadOp> { 29 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 30 const AnalysisState &state) const { 31 assert(opOperand.get().getType().isa<RankedTensorType>() && 32 "only tensor types expected"); 33 return true; 34 } 35 36 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 37 const AnalysisState &state) const { 38 assert(opOperand.get().getType().isa<RankedTensorType>() && 39 "only tensor types expected"); 40 return false; 41 } 42 43 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 44 const AnalysisState &state) const { 45 return {}; 46 } 47 48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 49 const BufferizationOptions &options) const { 50 auto readOp = cast<vector::TransferReadOp>(op); 51 assert(readOp.getShapedType().isa<TensorType>() && 52 "only tensor types expected"); 53 FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options); 54 if (failed(buffer)) 55 return failure(); 56 replaceOpWithNewBufferizedOp<vector::TransferReadOp>( 57 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(), 58 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), 59 readOp.getInBoundsAttr()); 60 return success(); 61 } 62 }; 63 64 /// Bufferization of vector.transfer_write. Replace with a new 65 /// vector.transfer_write that operates on a memref. 66 struct TransferWriteOpInterface 67 : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, 68 vector::TransferWriteOp> { 69 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 70 const AnalysisState &state) const { 71 assert(opOperand.get().getType().isa<TensorType>() && 72 "only tensor types expected"); 73 return true; 74 } 75 76 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 77 const AnalysisState &state) const { 78 assert(opOperand.get().getType().isa<TensorType>() && 79 "only tensor types expected"); 80 return true; 81 } 82 83 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 84 const AnalysisState &state) const { 85 assert(opOperand.get().getType().isa<TensorType>() && 86 "only tensor types expected"); 87 return {op->getOpResult(0)}; 88 } 89 90 BufferRelation bufferRelation(Operation *op, OpResult opResult, 91 const AnalysisState &state) const { 92 return BufferRelation::Equivalent; 93 } 94 95 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 96 const BufferizationOptions &options) const { 97 auto writeOp = cast<vector::TransferWriteOp>(op); 98 assert(writeOp.getShapedType().isa<TensorType>() && 99 "only tensor types expected"); 100 101 // Create a new transfer_write on buffer that doesn't have a return value. 102 FailureOr<Value> resultBuffer = 103 getBuffer(rewriter, writeOp.getSource(), options); 104 if (failed(resultBuffer)) 105 return failure(); 106 rewriter.create<vector::TransferWriteOp>( 107 writeOp.getLoc(), writeOp.getVector(), *resultBuffer, 108 writeOp.getIndices(), writeOp.getPermutationMapAttr(), 109 writeOp.getMask(), writeOp.getInBoundsAttr()); 110 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 111 112 return success(); 113 } 114 }; 115 116 } // namespace 117 } // namespace vector 118 } // namespace mlir 119 120 void mlir::vector::registerBufferizableOpInterfaceExternalModels( 121 DialectRegistry ®istry) { 122 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { 123 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx); 124 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx); 125 }); 126 } 127