1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Operation.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::vector;
19 
20 namespace mlir {
21 namespace vector {
22 namespace {
23 
24 /// Bufferization of vector.transfer_read. Replaced with a new
25 /// vector.transfer_read that operates on a memref.
26 struct TransferReadOpInterface
27     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
28                                                     vector::TransferReadOp> {
29   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
30                               const AnalysisState &state) const {
31     assert(opOperand.get().getType().isa<RankedTensorType>() &&
32            "only tensor types expected");
33     return true;
34   }
35 
36   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37                                const AnalysisState &state) const {
38     assert(opOperand.get().getType().isa<RankedTensorType>() &&
39            "only tensor types expected");
40     return false;
41   }
42 
43   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
44                                             const AnalysisState &state) const {
45     return {};
46   }
47 
48   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49                           BufferizationState &state) const {
50     auto readOp = cast<vector::TransferReadOp>(op);
51     assert(readOp.getShapedType().isa<TensorType>() &&
52            "only tensor types expected");
53     Value buffer = state.getBuffer(rewriter, readOp.getSource());
54     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
55         rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
56         readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
57         readOp.getInBoundsAttr());
58     return success();
59   }
60 };
61 
62 /// Bufferization of vector.transfer_write. Replace with a new
63 /// vector.transfer_write that operates on a memref.
64 struct TransferWriteOpInterface
65     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
66                                                     vector::TransferWriteOp> {
67   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
68                               const AnalysisState &state) const {
69     assert(opOperand.get().getType().isa<TensorType>() &&
70            "only tensor types expected");
71     return true;
72   }
73 
74   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
75                                const AnalysisState &state) const {
76     assert(opOperand.get().getType().isa<TensorType>() &&
77            "only tensor types expected");
78     return true;
79   }
80 
81   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
82                                             const AnalysisState &state) const {
83     assert(opOperand.get().getType().isa<TensorType>() &&
84            "only tensor types expected");
85     return {op->getOpResult(0)};
86   }
87 
88   BufferRelation bufferRelation(Operation *op, OpResult opResult,
89                                 const AnalysisState &state) const {
90     return BufferRelation::Equivalent;
91   }
92 
93   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94                           BufferizationState &state) const {
95     auto writeOp = cast<vector::TransferWriteOp>(op);
96     assert(writeOp.getShapedType().isa<TensorType>() &&
97            "only tensor types expected");
98 
99     // Create a new transfer_write on buffer that doesn't have a return value.
100     Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource());
101     rewriter.create<vector::TransferWriteOp>(
102         writeOp.getLoc(), writeOp.getVector(), resultBuffer,
103         writeOp.getIndices(), writeOp.getPermutationMapAttr(),
104         writeOp.getInBoundsAttr());
105     replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
106 
107     return success();
108   }
109 };
110 
111 } // namespace
112 } // namespace vector
113 } // namespace mlir
114 
115 void mlir::vector::registerBufferizableOpInterfaceExternalModels(
116     DialectRegistry &registry) {
117   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
118     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
119     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
120   });
121 }
122