1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Operation.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::vector;
19 
20 namespace mlir {
21 namespace vector {
22 namespace {
23 
24 /// Bufferization of vector.transfer_read. Replaced with a new
25 /// vector.transfer_read that operates on a memref.
26 struct TransferReadOpInterface
27     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
28                                                     vector::TransferReadOp> {
29   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
30                               const AnalysisState &state) const {
31     assert(opOperand.get().getType().isa<RankedTensorType>() &&
32            "only tensor types expected");
33     return true;
34   }
35 
36   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37                                const AnalysisState &state) const {
38     assert(opOperand.get().getType().isa<RankedTensorType>() &&
39            "only tensor types expected");
40     return false;
41   }
42 
43   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
44                                             const AnalysisState &state) const {
45     return {};
46   }
47 
48   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49                           BufferizationState &state) const {
50     auto readOp = cast<vector::TransferReadOp>(op);
51     assert(readOp.getShapedType().isa<TensorType>() &&
52            "only tensor types expected");
53 
54     // TransferReadOp always reads from the bufferized op.source().
55     Value buffer =
56         *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
57     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
58         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
59         readOp.permutation_map(), readOp.padding(), readOp.mask(),
60         readOp.in_boundsAttr());
61     return success();
62   }
63 };
64 
65 /// Bufferization of vector.transfer_write. Replace with a new
66 /// vector.transfer_write that operates on a memref.
67 struct TransferWriteOpInterface
68     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
69                                                     vector::TransferWriteOp> {
70   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
71                               const AnalysisState &state) const {
72     assert(opOperand.get().getType().isa<TensorType>() &&
73            "only tensor types expected");
74     return true;
75   }
76 
77   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
78                                const AnalysisState &state) const {
79     assert(opOperand.get().getType().isa<TensorType>() &&
80            "only tensor types expected");
81     return true;
82   }
83 
84   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
85                                             const AnalysisState &state) const {
86     assert(opOperand.get().getType().isa<TensorType>() &&
87            "only tensor types expected");
88     return {op->getOpResult(0)};
89   }
90 
91   BufferRelation bufferRelation(Operation *op, OpResult opResult,
92                                 const AnalysisState &state) const {
93     return BufferRelation::Equivalent;
94   }
95 
96   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
97                           BufferizationState &state) const {
98     auto writeOp = cast<vector::TransferWriteOp>(op);
99     assert(writeOp.getShapedType().isa<TensorType>() &&
100            "only tensor types expected");
101 
102     // Create a new transfer_write on buffer that doesn't have a return value.
103     // Leave the previous transfer_write to dead code as it still has uses at
104     // this point.
105     FailureOr<Value> resultBuffer =
106         state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
107     if (failed(resultBuffer))
108       return failure();
109     rewriter.create<vector::TransferWriteOp>(
110         writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
111         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
112     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
113 
114     return success();
115   }
116 };
117 
118 } // namespace
119 } // namespace vector
120 } // namespace mlir
121 
122 void mlir::vector::registerBufferizableOpInterfaceExternalModels(
123     DialectRegistry &registry) {
124   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
125     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
126     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
127   });
128 }
129